【发布时间】:2018-12-26 22:21:55
【问题描述】:
我正在使用 TF Estimator 在数据集上训练我的模型。对于前几次训练迭代,我想冻结网络中的某些层。对于剩余的迭代,我想解冻这些层。
我发现了一些解决方案,我们在估算器的 model_fn 中有两个不同的优化器 train_ops。
def ModelFunction(features, labels, mode, params):
if mode == tf.estimator.ModeKeys.TRAIN:
layerTrainingVars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "LayerName")
#Train Op for freezing layers
freeze_train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step(), var_list=layerTrainingVars)
#Train Op for training all layers
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
#Based on whether we want to freeze or not, we send the corresponding train_op to the estimatorSpec. How do I do this?
estimatorSpec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=freeze_train_op)
return estimatorSpec
对于上述方案,可以根据train_op返回对应的EstimatorSpec。我尝试使用 freeze_train_op 进行一些训练迭代,然后终止进程,并将 train_op 更改为在代码中没有层冻结。执行此操作后,出现检查点错误,表示检查点中保存的图形/变量不同。我猜第一组迭代没有保存冻结层。如何以编程方式切换 train_ops 以使检查点也起作用?
有没有更好的方法来冻结/解冻层以在 TF.Estmator 中进行训练?
【问题讨论】:
标签: python tensorflow tensorflow-estimator