【发布时间】:2022-02-10 02:45:08
【问题描述】:
我正在尝试将 caffe 模型转换为 keras,我已经成功地能够同时使用 MMdnn 甚至 caffe-tensorflow。我的输出是.npy 文件和.pb 文件。我对.pb 文件不太满意,所以我坚持使用包含权重和偏差的.npy 文件。我重构了一个 mAlexNet 网络如下:
import tensorflow as tf
from tensorflow import keras
from keras.layers import Conv2D, MaxPool2D, Dropout, Dense, Flatten
def define_malexnet():
input = keras.Input(shape=(224, 224, 3), name='data')
x = Conv2D(16, kernel_size=(11,11), strides=(4,4), activation='relu', name='conv1')(input)
x = MaxPool2D(pool_size=(3,3), strides=(2,2), padding='same', name='pool1')(x)
x = Conv2D(20, kernel_size=(5,5), strides=(1,1), activation='relu', name='conv2')(x)
x = MaxPool2D(pool_size=(3,3), strides=(2,2), name='pool2')(x)
x = Conv2D(30, kernel_size=(3,3), strides=(1,1), activation='relu', name='conv3')(x)
x = MaxPool2D(pool_size=(3,3), strides=(2,2), name='pool3')(x)
x = Flatten()(x)
x = Dense(48, activation='relu', name='fc4')(x)
output = Dense(2, activation='softmax', name='fc5')(x)
occupancy_model = keras.Model(input, output, name='occupancy_malexnet')
occupancy_model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
return occupancy_model
然后我尝试使用以下代码 sn-p 加载权重:
import numpy as np
weights_data = np.load('weights.npy', allow_pickle=True).item()
model = define_malexnet()
for layer in model.layers:
if layer.name in weights_data.keys():
layer_weights = weights_data[layer.name]
layer.set_weights((layer_weights['weights'], layer_weights['bias']))
在这个过程中我得到一个错误:
ValueError: Layer conv1 weight shape (16,) 不兼容 提供权重形状 (1, 1, 1, 16)。
现在据我了解,这是因为不同的后端以及它们如何初始化权重,但我还没有找到解决这个问题的方法。我的问题是,如何调整从文件加载的权重以适合我的 keras 模型?链接到weights.npy 文件https://drive.google.com/file/d/1QKzY-WxiUnf9VnlhWQS38DE3uF5I_qTl/view?usp=sharing。
【问题讨论】:
-
如果
layer_weights['weights']只是一个 1 x 1 的 conv,你能应用一个ravel()或flatten()到它,这样你就可以把它变成一维 NumPy 数组吗? -
不是,如果我尝试
ravel()offlatten(),第一个卷积层的权重形状会起作用:ValueError: Layer conv1 weight shape (11, 11, 3, 16) is not与提供的重量形状 (5808,) 兼容。 -
binary_crossentry 应该有密集的 1 个神经元
-
问题是你的偏差向量。它的形状为 4D 张量。把它弄平。除了消除偏差之外,我还下载了您的权重并重用了您的代码。它有效!
-
贴出权重的来源,以便测试加载
标签: python tensorflow keras caffe