在 python caffe 接口中,caffe.Net 对象实例化了加载.prototxt 文件,该文件定义了网络架构。您可以使用具有以下属性的caffe.Net 对象来访问网络上的各种信息。
-
blob_loss_weights:按层名称索引的网络 blob 损失权重的 OrderedDict(从下到上,即输入到输出)
-
blobs:按层名称索引的网络 blob 的 OrderedDict(从下到上,即输入到输出)
-
bottom_names:网络中的所有底层名称
-
inputs:此网络的输入
-
layer_dict:按层名称索引的网络层的 OrderedDict(从下到上,即输入到输出)
-
layers:caffe._caffe.LayerVec - 其元素为网络中caffe.Layer 对象的列表,caffe.Layer 类有blobs 字段用于层的参数内存和type 用于层类型(例如、卷积、数据等)
-
outputs:来自该网络的输出
-
params:按名称索引的网络参数的 OrderedDict(从下到上,即输入到输出);每个都是多个 blob 的列表(例如,权重和偏差)
-
top_names:网络中的所有知名人士
您可以使用 caffe.Net.params 来访问图层的可学习参数,同时使用 caffe.Net.layer_dict 来访问图层信息。
caffe.Net.params 是有序字典,其中键是层名称,值是参数的 blob(例如,权重和偏差),在卷积层的情况下,blob 的第一个元素是权重和第二个Blob 的元素是偏差:
-
caffe.Net.params['layer_name'][0] : 重量
-
caffe.Net.params['layer_name'][1]:偏见
请注意,访问 blob 的内存应使用 caffe.Net.params['layer_name'][0].data 完成,更新 blob 的内存应使用 ... 完成,例如 caffe.Net.params['layer_name'][0].data[...]
以下代码说明了从 numpy 保存的文件 (.npy) 中加载可学习参数:
def load_weights_and_biases(network):
k_list = list(network.params.keys())
suffix = ["weight", "bias"]
num_layers = len(network.layer_dict)
for idx, layer_name in enumerate(network.layer_dict):
print(f"\n-----------------------------")
print(f"layer index: {idx}/{num_layers}")
print(f"layer name: '{layer_name}''")
print(f"layer type: '{detection_nw.layers[idx].type}' ")
if layer_name in k_list:
params = network.params[layer_name]
print(f"{len(params)} learnable parameters in '{detection_nw.layers[idx].type}' type")
for i, p in enumerate(params):
#print(f"\tparams[{i}]: {p}")
#print(f"\tparams[{i}] CxHxW: {p.channels}x{p.height}x{p.width}")
print(f"\tp[{i}]: {p.data.shape} of {p.data.dtype}")
param_file_path = f"./npy_save/{layer_name}_{suffix[i]}.npy"
param_file = Path(param_file_path)
if param_file.exists():
print(f"\tload {param_file_path}")
arr = np.load(param_file_path, allow_pickle=True)
if p.data.shape == arr.shape:
print(f"\tset {layer_name}_{suffix[i]} with arr:shape {arr.shape}, type {arr.dtype}")
p.data[...] = arr
else:
print(f"p.data.shape: {p.data.shape} is not equal to arr.shape: {arr.shape}")
break
else:
print(f"{param_file_path} is not exits!!")
break
else:
print(f"no learnable parameters in '{layer_name}' of '{network.layers[idx].type}' type'")
Blob 类型在 python caffe(又名 pycaffe)接口中定义为 caffe._caffe.Blob。在import caffe 之后使用help(caffe._caffe.Blob) 以及帮助输出的此处定义的数据描述符部分中描述的名称作为属性。
有关 Caffe 中 Blob 的更多详细信息参考