【Question Title】: Reading network parameters from caffe .prototxt model definition in Python
【Posted】: 2017-09-07 15:09:58
【Question】:

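I want to read the parameters of a caffe network defined in a .prototxt file from Python. The layer objects in layer_dict only tell me, for example, that a layer is a "Convolution" layer, but not the kernel_size, strides, etc. that are explicitly defined in the .prototxt file.

Suppose I have a model.prototxt like this: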

name: "Model"
layer {
  name: "data"
  type: "Input"
  top: "data"
  input_param {
    shape: {
      dim: 64
      dim: 1
      dim: 28
      dim: 28
    }
  }
}
layer {
  name: "conv2d_1"
  type: "Convolution"
  bottom: "data"
  top: "conv2d_1"
  convolution_param {
    num_output: 32
    kernel_size: 3
    stride: 1
    weight_filler {
      type: "gaussian" # initialize the filters from a Gaussian
      std: 0.01        # distribution with stdev 0.01 (default mean: 0)
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}

layer {
  name: "dense_1"
  type: "InnerProduct"
  bottom: "conv2d_1"
  top: "out"
  inner_product_param {
    num_output: 1024
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}

I found that the model can be parsed like this:

from caffe.proto import caffe_pb2
import google.protobuf.text_format

net = caffe_pb2.NetParameter()
# Parse the text-format model definition into the NetParameter message
with open('model.prototxt', 'r') as f:
    google.protobuf.text_format.Merge(f.read(), net)

But I don't know how to get at the fields of the protobuf message in the resulting object.

【Comments】:

    Tags: python protocol-buffers caffe pycaffe


    【Solution 1】:

    You can iterate over the layers and access each layer's corresponding parameter message, for example:

    for layer in net.layer:
        if layer.type == 'Convolution':
            layer.convolution_param.bias_term = True  # bias term, for example
    

    The appropriate *_param field for each layer type can be found in caffe.proto, for example:

    optional ConvolutionParameter convolution_param = 106;
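
The loop above sets a field; reading works through the same attribute access. Below is a minimal sketch that collects the values the question asks about. The helper name `conv_params` is hypothetical, and the code assumes a caffe.proto version in which `kernel_size` and `stride` are repeated fields (older versions declared them as single optional values, in which case the `list(...)` calls would need to be dropped).

```python
def conv_params(net):
    """Collect hyperparameters of every Convolution layer in a parsed NetParameter.

    `net` is expected to be the message filled by text_format.Merge.
    """
    result = {}
    for layer in net.layer:
        if layer.type != 'Convolution':
            continue
        p = layer.convolution_param
        result[layer.name] = {
            'num_output': p.num_output,
            # Assumption: kernel_size and stride are repeated fields,
            # so they are copied into plain Python lists.
            'kernel_size': list(p.kernel_size),
            'stride': list(p.stride),
        }
    return result
```

Calling `conv_params(net)` on the `net` parsed in the question returns one dictionary entry per convolution layer, keyed by layer name. The same pattern should apply to other layer types, e.g. `layer.inner_product_param.num_output` for "InnerProduct" layers.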
    

      【Comments】:

      【Solution 2】:

      Caffe prototxt files are built on Google Protobuf. To access them programmatically, you need to use that package. Here is an example script (source):

      from caffe.proto import caffe_pb2
      import google.protobuf.text_format as txtf
      
      net = caffe_pb2.NetParameter()
      
      fn = '/tmp/net.prototxt'
      with open(fn) as f:
          s = f.read()
          txtf.Merge(s, net)
      
      net.name = 'my new net'
      layerNames = [l.name for l in net.layer]
      idx = layerNames.index('fc6')
      l = net.layer[idx]
      l.param[0].lr_mult = 1.3
      
      outFn = '/tmp/newNet.prototxt'
      print('writing', outFn)
      with open(outFn, 'w') as f:
          f.write(str(net))
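
Two lines in the script above can raise: `layerNames.index('fc6')` raises ValueError when no layer has that name (the model in the question, for instance, has no 'fc6' layer), and `l.param[0]` raises IndexError when the layer declares no `param` block. A minimal defensive sketch follows; the helper names `find_layer` and `set_lr_mult` are hypothetical and not part of Caffe.

```python
def find_layer(net, name):
    """Return the first layer called `name`, or None if there is no such layer."""
    for layer in net.layer:
        if layer.name == name:
            return layer
    return None


def set_lr_mult(net, name, value):
    """Set lr_mult on the first param entry of layer `name`.

    Returns True on success, False if the layer is missing or has no
    param entries to modify.
    """
    layer = find_layer(net, name)
    if layer is None or len(layer.param) == 0:
        return False
    layer.param[0].lr_mult = value
    return True
```

With the `net` parsed above, `set_lr_mult(net, 'fc6', 1.3)` performs the same change as the script but reports failure instead of raising.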
      

      【Comments】: