【发布时间】:2019-09-03 01:25:00
【问题描述】:
我想在生成器中设置我的 LSTM 隐藏状态。但是,状态集只在生成器之外起作用:
K.set_value(model.layers[0].states[0], np.random.randn(batch_size,num_outs)) # this works
def gen_data():
x = np.zeros((batch_size, num_steps, num_input))
y = np.zeros((batch_size, num_steps, num_output))
while True:
for i in range(batch_size):
K.set_value(model.layers[0].states[0], np.random.randn(batch_size,num_outs)) # error
x[i, :, :] = X_train[gen_data.current_idx]
y[i, :, :] = Y_train[gen_data.current_idx]
gen_data.current_idx += 1
yield x, y
gen_data.current_idx = 0
在fit_generator函数中调用生成器:
model.fit_generator(gen_data(), len(X_train)//batch_size, 1, validation_data=None)
这是我打印状态时的结果:
print(model.layers[0].states[0])
<tf.Variable 'lstm/Variable:0' shape=(1, 2) dtype=float32>
这是生成器中出现的错误:
ValueError: Tensor("Placeholder_1:0", shape=(1, 2), dtype=float32) must be from the same graph as Tensor("lstm/Variable:0", shape=(), dtype=resource)
我做错了什么?
【问题讨论】:
-
根据 keras 文档,
fit_generator函数“用 Python 生成器(或序列实例)逐批生成的数据训练模型”,所以你需要一个 Python @ 987654321@这里。 -
@KacperFloriański 我没有写下我的完整生成器,但现在我已经编辑了它
标签: python tensorflow keras lstm