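[Posted]: 2021-04-11 13:49:58
[Question]:
I am training an RNN-based English-to-Hindi neural machine translation model with an LSTM layer and an attention layer. Training fails with this error: (0) Invalid argument: logits and labels must be broadcastable: logits_size=[384,2971] labels_size=[864,2971]
My model summary is: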
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, None)] 0
__________________________________________________________________________________________________
input_2 (InputLayer) [(None, None)] 0
__________________________________________________________________________________________________
embedding (Embedding) (None, None, 40) 93760 input_1[0][0]
__________________________________________________________________________________________________
embedding_1 (Embedding) (None, None, 40) 118840 input_2[0][0]
__________________________________________________________________________________________________
conv1d (Conv1D) (None, None, 16) 7056 embedding[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, None, 16) 10256 embedding_1[0][0]
__________________________________________________________________________________________________
lstm (LSTM) [(None, 40), (None, 9120 conv1d[0][0]
__________________________________________________________________________________________________
lstm_1 (LSTM) [(None, None, 40), ( 9120 conv1d_1[0][0]
lstm[0][1]
lstm[0][2]
__________________________________________________________________________________________________
dense (Dense) (None, None, 2971) 121811 lstm_1[0][0]
==================================================================================================
Total params: 369,963
Trainable params: 369,963
Non-trainable params: 0
__________________________________________________________________________________________________
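The parameter counts in the summary can be re-derived by hand, and doing so reveals details the summary does not print. Below is a minimal sketch that assumes default Keras layers (biases enabled, Conv1D with strides=1 and dilation_rate=1). All helper names are hypothetical. The sketch recovers the vocabulary sizes (2344 source, 2971 target), confirms that the Dense output matches the 2971-way labels, and infers Conv1D kernel sizes of 11 (encoder) and 16 (decoder). With Keras' default padding='valid', a Conv1D returns seq_len - kernel_size + 1 steps. A 27-step decoder input would therefore leave conv1d_1 with only 12 steps. That matches 384 = 32 * 12 and 864 = 32 * 27 in the error, if the decoder sequences are 27 tokens long (an assumption, since the generator code is not shown).

```python
# Re-derive the parameter counts from the model summary (hypothetical helpers).
# Assumes default Keras layers: biases enabled, Conv1D strides=1, dilation_rate=1.

def embedding_params(vocab_size: int, dim: int) -> int:
    # one dim-sized vector per vocabulary entry
    return vocab_size * dim

def conv1d_params(kernel_size: int, in_channels: int, filters: int) -> int:
    # kernel of shape (kernel_size, in_channels, filters) plus one bias per filter
    return (kernel_size * in_channels + 1) * filters

def lstm_params(input_dim: int, units: int) -> int:
    # 4 gates, each with input weights, recurrent weights and a bias
    return 4 * (input_dim + units + 1) * units

def dense_params(input_dim: int, units: int) -> int:
    # weight matrix plus one bias per output unit
    return (input_dim + 1) * units

def conv1d_valid_length(seq_len: int, kernel_size: int) -> int:
    # padding='valid' (the Keras default) drops kernel_size - 1 timesteps
    return seq_len - kernel_size + 1

src_vocab, tgt_vocab = 93760 // 40, 118840 // 40
print(src_vocab, tgt_vocab)                                  # 2344 2971
print(conv1d_params(11, 40, 16), conv1d_params(16, 40, 16))  # 7056 10256
print(lstm_params(16, 40), dense_params(40, tgt_vocab))      # 9120 121811
print(conv1d_valid_length(27, 16))                           # 12
```

Under these assumptions, the vocabulary axis is consistent and only the time axis disagrees.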
The model is compiled with:
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
train_samples = len(X_train)
val_samples = len(X_test)
batch_size = 32
epochs = 100
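`loss='categorical_crossentropy'` compares the predicted distribution with the one-hot target at every position. Both tensors therefore need the same shape, including the same number of timesteps. The following is a pure-Python illustration of the formula. It is not Keras' actual implementation, which works on tensors and handles batching and numerical stability differently, and the function name is hypothetical.

```python
import math

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Average of -sum(y_true * log(y_pred)) over timesteps.

    y_true, y_pred: lists of timesteps, each a list of per-token probabilities
    (y_true one-hot). Both must have the same number of timesteps and vocab size.
    """
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"timesteps differ: labels={len(y_true)}, predictions={len(y_pred)}")
    total = 0.0
    for t_true, t_pred in zip(y_true, y_pred):
        if len(t_true) != len(t_pred):
            raise ValueError("vocabulary sizes differ")
        # eps guards against log(0) for tokens predicted with probability 0
        total -= sum(p * math.log(max(q, eps)) for p, q in zip(t_true, t_pred))
    return total / len(y_true)

print(categorical_crossentropy([[0, 1, 0]], [[0.1, 0.8, 0.1]]))
```

Passing 27 target timesteps against 12 predicted ones raises `ValueError` here. This is the same class of mismatch the TensorFlow error reports.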
The fit call is:
history = model.fit_generator(generator = generate_batch(X_train, y_train, batch_size = batch_size),
steps_per_epoch = train_samples//batch_size,
epochs=epochs,
validation_data = generate_batch(X_test, y_test, batch_size = batch_size),
validation_steps = val_samples//batch_size)
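The `generate_batch` function itself is not shown. For reference, here is a hypothetical pure-Python sketch of what such a generator typically yields for a Keras encoder-decoder: `([encoder_input, decoder_input], decoder_target)`, with the decoder input and target padded to the same `tgt_len`. It assumes that token id 0 is padding and that each target sequence already contains start and end tokens. Real code would yield NumPy arrays rather than lists.

```python
def generate_batch_sketch(X, y, batch_size, src_len, tgt_len, vocab_size):
    """Yield ([encoder_in, decoder_in], decoder_target) batches indefinitely.

    X, y: lists of token-id lists (hypothetical layout; id 0 = padding).
    decoder_in is y[:-1] and decoder_target is y[1:] (teacher forcing),
    both padded to tgt_len so their timestep counts always match.
    """
    def pad(seq, length):
        # truncate to length, then right-pad with the padding id 0
        seq = list(seq[:length])
        return seq + [0] * (length - len(seq))

    def one_hot(ids):
        return [[1.0 if v == i else 0.0 for v in range(vocab_size)] for i in ids]

    while True:  # generators passed to Keras are expected to loop indefinitely
        for start in range(0, len(X), batch_size):
            xs = X[start:start + batch_size]
            ys = y[start:start + batch_size]
            enc_in = [pad(s, src_len) for s in xs]
            dec_in = [pad(t[:-1], tgt_len) for t in ys]           # drop end token
            dec_out = [one_hot(pad(t[1:], tgt_len)) for t in ys]  # drop start token
            yield [enc_in, dec_in], dec_out
```

Per the deprecation warning in the log, the same generator can be passed to `model.fit`. Matching lengths inside the generator is not enough on its own, because the model's output must also have `tgt_len` timesteps. Keras' `Conv1D` defaults to `padding='valid'`, which shortens a sequence by `kernel_size - 1` steps. `padding='same'` or `padding='causal'` keeps the length unchanged.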
The error is:
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:1844: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.
warnings.warn('`Model.fit_generator` is deprecated and '
Epoch 1/100
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-39-dc64566948be> in <module>()
3 epochs=epochs,
4 validation_data = generate_batch(X_test, y_test, batch_size = batch_size),
----> 5 validation_steps = val_samples//batch_size)
7 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: logits and labels must be broadcastable: logits_size=[384,2971] labels_size=[864,2971]
[[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at <ipython-input-39-dc64566948be>:5) ]]
[[gradient_tape/model_1/embedding_3/embedding_lookup/Reshape/_56]]
(1) Invalid argument: logits and labels must be broadcastable: logits_size=[384,2971] labels_size=[864,2971]
[[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at <ipython-input-39-dc64566948be>:5) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_11885]
Function call stack:
train_function -> train_function
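The two sizes in the message are easier to read once the time axis is folded back out. The sketch below assumes two things: the loss flattens (batch, timesteps, vocab) to (batch * timesteps, vocab) before comparing, and the failing batch held batch_size = 32 sequences. The helper names are hypothetical.

```python
def flattened_rows(batch: int, timesteps: int) -> int:
    # rows left after folding the time axis into the batch axis
    return batch * timesteps

def timesteps_from_rows(rows: int, batch: int) -> int:
    # invert the flattening; fails loudly if the batch-size assumption is wrong
    if rows % batch:
        raise ValueError(f"{rows} rows is not a multiple of batch size {batch}")
    return rows // batch

batch_size = 32
logits_rows, label_rows = 384, 864
print(timesteps_from_rows(logits_rows, batch_size))  # 12: model output timesteps
print(timesteps_from_rows(label_rows, batch_size))   # 27: target timesteps
```

Under those assumptions, the model emits 12 timesteps per sequence while the targets carry 27. The mismatch is therefore in sequence length, not in the 2971-token vocabulary.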
[Discussion]:
Tags: python-3.x tensorflow keras conv-neural-network lstm