【问题标题】:How to load a trained MXnet model?如何加载经过训练的 MXnet 模型?
【发布时间】:2018-04-21 18:48:59
【问题描述】:

我已经使用 MXnet 训练了一个网络,但我不确定如何保存和加载参数以供以后使用。首先我定义和训练网络:

    dataIn = mx.sym.var('data')
    fc1 = mx.symbol.FullyConnected(data=dataIn, num_hidden=100)
    act1 = mx.sym.Activation(data=fc1, act_type="relu")
    fc2 = mx.symbol.FullyConnected(data=act1, num_hidden=50)
    act2 = mx.sym.Activation(data=fc2, act_type="relu")
    fc3 = mx.symbol.FullyConnected(data=act2, num_hidden=25)
    act3 = mx.sym.Activation(data=fc3, act_type="relu")
    fc4 = mx.symbol.FullyConnected(data=act3, num_hidden=10)
    act4 = mx.sym.Activation(data=fc4, act_type="relu")
    fc5 = mx.symbol.FullyConnected(data=act4, num_hidden=2)
    lenet = mx.sym.SoftmaxOutput(data=fc5, name='softmax',normalization = 'batch')


# create iterator around training and validation data
train_iter = mx.io.NDArrayIter(data=data[:ntrain], label = phen[:ntrain],batch_size=batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(data=data[ntrain:], label=phen[ntrain:], batch_size=batch_size)

# create a trainable module on GPU 0
lenet_model = mx.mod.Module(symbol=lenet, context=mx.gpu())
# train with the same
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='adam',
                optimizer_params={'learning_rate':0.00001},
                eval_metric='f1',
                batch_end_callback = mx.callback.Speedometer(batch_size, 10),
                num_epoch=1000)

这个模型在测试集上表现很好,所以我想保留它。接下来,我保存网络布局和参数化:

lenet.save('./testNet_symbol.mxnet')
lenet_model.save_params('./testNet_module.mxnet')

我在加载网络时可以找到的所有文档似乎都在训练例程中实现了保存功能,以在每个时期结束时保存网络参数。我在训练过程中没有设置这些检查点其他方法使用mx.model.FeedForward类,这似乎不合适。还有其他方法从 .json 文件加载网络,由于我的保存功能,我没有该文件。训练完成后如何保存/加载网络?

【问题讨论】:

    标签: deep-learning mxnet


    【解决方案1】:

    您只需执行此操作即可保存:

    lenet_model.save_checkpoint('lenet', num_epoch, save_optimizer_states=True)
    

    如果 states 标志设置为 True,这将创建 3 个文件,否则 2 个文件:

    .params(权重), .json(符号), .states

    this 加载:

    lenet_model = mx.mod.Module.load(prefix,epoch)
    lenet_model.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
    

    【讨论】:

      猜你喜欢
      • 2017-07-28
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-01-11
      • 1970-01-01
      • 1970-01-01
      • 2015-10-09
      相关资源
      最近更新 更多