【问题标题】:Is there a way to train a CNN model, save the weights of this CNN, and then use this weights to retrain this CNN for other train data?有没有办法训练一个 CNN 模型,保存这个 CNN 的权重,然后使用这个权重重新训练这个 CNN 以获得其他训练数据?
【发布时间】:2020-03-15 20:03:49
【问题描述】:

有人告诉我这种类型的实验。 第一步是训练一个 CNN 并保持权重,第二步是使用这些权重来重新训练这个 CNN,但这一次是向你的训练集添加更多数据(微调)。

我想这有点像迁移学习,但需要训练 CNN。 有没有办法在训练 CNN 之前选择权重并将这些选择的权重归档?

所以到目前为止我所做的是训练我的 CNN 模型并将权重保存到 h5 文件中,代码如下

model.compile(loss='categorical_crossentropy', optimizer=opt,metrics=['accuracy'])
validation_data=(x_testcnn, y_test))
checkpoint_path= 'scratchmodel.best.h5'
save_dir = os.path.join(os.getcwd(), 'weights')
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                             save_weights_only=True,
                                             verbose=1)
 cnnhistory=model.fit(x_traincnn, 
      y_train,
      batch_size=16,           
      epochs=400,
      validation_data=(x_testcnn,y_test),
      callbacks=[cp_callback])

现在我想用相同的权重重新训练相同的 CNN,但这次将数据添加到训练集中。 有没有办法做到这一点? 感谢您的帮助。

【问题讨论】:

  • 请说明您到目前为止所尝试的内容,以及您遇到的具体问题。请不要将整个问题转嫁给其他人来解决。
  • 谢谢,我尽量说清楚了
  • 在文档中找不到这些信息吗?
  • 我还没有找到任何关于它的信息

标签: python conv-neural-network


【解决方案1】:

是的,您只需将权重加载到新创建的模型中,然后使用新数据进行训练。

from tensorflow.python.keras.models import load_model #Tensorflow 2.0

new_model.compile(loss='categorical_crossentropy', optimizer=opt,metrics=['accuracy'])
new_model = load_model(filepath, compile=False) #compile=False allows you to load saved optimizer state

new_model.fit(...) # Fit on new data, leveraging training on old data

【讨论】:

  • 非常感谢。这真的很有帮助。我想问另一个关于准确性的问题。你能用下面的代码告诉我: plt.plot(cnnhistory.history['accuracy']) plt.plot(cnnhistory.history['val_accuracy']) plt.title('model accuracy') plt.ylabel('acc ') plt.xlabel('epoch') 我在绘制哪个模型的准确度?非常感谢!
  • 很高兴您喜欢这个答案,请点赞或选择作为答案。看起来您正在绘制训练和验证准确性。验证准确度 (val_accuracy) 可以更好地衡量真实世界的性能,因为您的模型在训练期间没有看到该数据。
猜你喜欢
  • 2018-11-14
  • 2019-09-02
  • 2017-12-12
  • 1970-01-01
  • 2019-02-05
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多