【发布时间】:2018-10-15 17:34:33
【问题描述】:
如何在 lstm 中提前停止。
我使用的是 python tensorflow,但不是 keras。
如果您能提供一个示例 python 代码,我将不胜感激。
问候
【问题讨论】:
标签: python-3.x tensorflow lstm
如何在 lstm 中提前停止。
我使用的是 python tensorflow,但不是 keras。
如果您能提供一个示例 python 代码,我将不胜感激。
问候
【问题讨论】:
标签: python-3.x tensorflow lstm
您可以使用checkpoints:
from keras.callbacks import EarlyStopping
earlyStop=EarlyStopping(monitor="val_loss",verbose=2,mode='min',patience=3)
history=model.fit(xTrain,yTrain,epochs=100,batch_size=10,validation_data=(xTest,yTest) ,verbose=2,callbacks=[earlyStop])
即使经过 3 个 epoch(patience=3),当 "val_loss" 没有减少时(mode='min') 训练将停止
#Didn't realize u were note using keras
【讨论】:
你可以通过一点搜索找到它 https://github.com/mmuratarat/handson-ml/blob/master/11_deep_learning.ipynb
max_checks_without_progress = 20
checks_without_progress = 0
best_loss = np.infty
....
if loss_val < best_loss:
save_path = saver.save(sess, './my_mnist_model.ckpt')
best_loss = loss_val
check_without_progress = 0
else:
check_without_progress +=1
if check_without_progress > max_checks_without_progress:
print("Early stopping!")
break
print("Epoch: {:d} - ".format(epoch), \
"Training Loss: {:.5f}, ".format(loss_train), \
"Training Accuracy: {:.2f}%, ".format(accuracy_train*100), \
"Validation Loss: {:.4f}, ".format(loss_val), \
"Best Loss: {:.4f}, ".format(best_loss), \
"Validation Accuracy: {:.2f}%".format(accuracy_val*100))
【讨论】: