【发布时间】:2017-05-16 11:49:14
【问题描述】:
我正在尝试在 Tensorflow 工作流程中冻结 Keras 层。这就是我定义图表的方式:
import tensorflow as tf
from keras.layers import Dropout, Dense, Embedding, Flatten
from keras import backend as K
from keras.objectives import binary_crossentropy
import tensorflow as tf
sess = tf.Session()
from keras import backend as K
K.set_session(sess)
labels = tf.placeholder(tf.float32, shape=(None, 1))
user_id_input = tf.placeholder(tf.float32, shape=(None, 1))
item_id_input = tf.placeholder(tf.float32, shape=(None, 1))
max_user_id = all_ratings['user_id'].max()
max_item_id = all_ratings['item_id'].max()
embedding_size = 30
user_embedding = Embedding(output_dim=embedding_size, input_dim=max_user_id+1,
input_length=1, name='user_embedding', trainable=all_trainable)(user_id_input)
item_embedding = Embedding(output_dim=embedding_size, input_dim=max_item_id+1,
input_length=1, name='item_embedding', trainable=all_trainable)(item_id_input)
user_vecs = Flatten()(user_embedding)
item_vecs = Flatten()(item_embedding)
input_vecs = concatenate([user_vecs, item_vecs])
x = Dense(30, activation='relu')(input_vecs)
x1 = Dropout(0.5)(x)
x2 = Dense(30, activation='relu')(x1)
y = Dense(1, activation='sigmoid')(x2)
loss = tf.reduce_mean(binary_crossentropy(labels, y))
train_step = tf.train.AdamOptimizer(0.004).minimize(loss)
然后我只训练模型:
with sess.as_default():
train_step.run(..)
当可训练标志设置为True 时,一切正常。然后当我将其设置为False 时,它不会冻结图层。
我还尝试使用train_step_freeze = tf.train.AdamOptimizer(0.004).minimize(loss, var_list=[user_embedding]) 仅在我想训练的变量上最小化,我得到:
('Trying to optimize unsupported type ', <tf.Tensor 'Placeholder_33:0' shape=(?, 1) dtype=float32>)
是否可以在 Tensorflow 中使用 Keras 层并冻结它们?
编辑
为了清楚起见,我想使用 Tensorflow 训练模型,而不是使用 model.fit()。在 Tensorflow 中执行此操作的方法似乎是将 var_list=[] 传递给 minimize() 方法。但是我在执行此操作时遇到错误:
('Trying to optimize unsupported type ', <tf.Tensor 'Placeholder_33:0' shape=(?, 1) dtype=float32>)
【问题讨论】:
-
我还想冻结一个 Keras 模型并使用 TensorFlow 训练剩余层。您找到解决问题的方法了吗?
-
不,我遇到了很多问题,最终决定使用 PyTorch。我建议你也这样做。很抱歉没有答案。
-
我找到了一种方法来冻结 Keras 模型并使用 TensorFlow 仅训练其他层。请查看我的回答,如果它回答了您的问题,请告诉我。
标签: tensorflow keras