[Posted]: 2021-09-21 11:29:56
[Question]:
I am using this code:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, LSTM, Input, Conv2D, Lambda
from tensorflow.keras import Model
def reshape_n(x):
    x = tf.compat.v1.placeholder_with_default(
        x,
        [None, 121, 240, 2])
    return x
input_shape = (121, 240, 1)
inputs = Input(shape=input_shape)
x = Conv2D(1, 1)(inputs)
x = LSTM(2, return_sequences=True)(x[0, :, :, :])
x = Lambda(reshape_n, (121, 240,2))(x[None, :, :, :])
x = Conv2D(1, 1)(x)
output = Dense(3, activation='softmax')(x)
model = Model(inputs, output)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics='accuracy')
print(model.summary())
train_x = np.random.randint(0, 30, size=(10, 121, 240))
train_y = np.random.randint(0, 3, size=(10, 121, 240))
train_y = tf.one_hot(tf.cast(train_y, 'int32'), depth=3)
model.fit(train_x, train_y, epochs=2)
I get:
logits and labels must be broadcastable: logits_size=[29040,3] labels_size=[290400,3]
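The two sizes in the error can be traced through the model's shapes. The sketch below uses NumPy arrays of zeros as hypothetical stand-ins for the tensors, and it fakes the LSTM output with zeros of the shape the layer would produce. It assumes `fit` sees all 10 samples in a single batch, which is what happens with the default `batch_size` of 32. `x[0, :, :, :]` keeps only the first sample of the batch and `x[None, :, :, :]` adds back a batch axis of size 1, so the model emits 1 × 121 × 240 = 29040 rows of logits against 10 × 121 × 240 = 290400 rows of labels. `placeholder_with_default` appears only to declare the static shape `[None, 121, 240, 2]` and does not restore the lost samples.

```python
import numpy as np

# Hypothetical stand-in for one training batch: 10 samples of shape (121, 240, 1)
batch = np.zeros((10, 121, 240, 1), dtype=np.float32)

# x[0, :, :, :] keeps only the FIRST sample of the batch
sliced = batch[0, :, :, :]
print(sliced.shape)  # (121, 240, 1)

# The LSTM now treats 121 as the batch axis and 240 as time steps;
# with 2 units and return_sequences=True it would output (121, 240, 2).
# Zeros are used here as a placeholder for that output.
lstm_out = np.zeros(sliced.shape[:-1] + (2,), dtype=np.float32)

# x[None, :, :, :] re-adds a batch axis of size 1, not 10
restored = lstm_out[None, :, :, :]
print(restored.shape)  # (1, 121, 240, 2)

# The loss flattens every axis except the class axis
logits_rows = int(np.prod(restored.shape[:-1]))  # 1 * 121 * 240
label_rows = batch.shape[0] * 121 * 240          # 10 * 121 * 240
print(logits_rows, label_rows)  # 29040 290400
```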
If I simply omit the LSTM layer:
x = Conv2D(1, 1)(inputs)
x = Conv2D(1, 1)(x)
output = Dense(3, activation='softmax')(x)
then the code runs without any problems!
[Discussion]:
You have actually omitted both the LSTM and the Lambda(reshape_n, ...). Are you sure that doesn't matter?
@CaptainTrojan: Yes, it doesn't matter. The Lambda layer only makes sense when the LSTM is used.
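A minimal sketch of one possible fix. It assumes the intent of `x[0, :, :, :]` was to run the LSTM along the 240 columns of each of the 121 rows, which is what the slice effectively did for the first sample. Instead of slicing off the batch axis, the LSTM is wrapped in `TimeDistributed`, which applies the same layer to every row of every sample. The batch axis then survives, the output shape becomes `(None, 121, 240, 3)` to match the labels, and the `Lambda(reshape_n, ...)` step appears to be unnecessary. The one-hot encoding with `np.eye` and the explicit channel axis on `train_x` are substitutions made here, not part of the original code.

```python
import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Dense, Input, LSTM, TimeDistributed

inputs = Input(shape=(121, 240, 1))
x = Conv2D(1, 1)(inputs)                    # (None, 121, 240, 1)
# Apply the same LSTM to each of the 121 rows of every sample, treating
# the 240 columns as time steps, so the batch axis is never dropped.
x = TimeDistributed(LSTM(2, return_sequences=True))(x)  # (None, 121, 240, 2)
x = Conv2D(1, 1)(x)                         # (None, 121, 240, 1)
output = Dense(3, activation='softmax')(x)  # (None, 121, 240, 3)

model = Model(inputs, output)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Same random data as in the question, with the channel axis made explicit
train_x = np.random.randint(0, 30, size=(10, 121, 240, 1)).astype('float32')
train_y = np.random.randint(0, 3, size=(10, 121, 240))
train_y = np.eye(3, dtype='float32')[train_y]  # one-hot: (10, 121, 240, 3)

# Logits and labels now both have 10 * 121 * 240 rows
model.fit(train_x, train_y, epochs=1, verbose=0)
```

Wrapping the LSTM in `TimeDistributed` keeps the per-row behaviour of the original slice while leaving the batch dimension to Keras, so no manual reshaping with `placeholder_with_default` is required.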
Tags: deep-learning lstm tensorflow2.0 tf.keras