【问题标题】:TensorFlow 2 model.predict() output has incorrect shape: input data shape, not label data shapeTensorFlow 2 model.predict() 输出的形状不正确:输入数据形状,而不是标签数据形状
【发布时间】:2022-08-16 22:08:58
【问题描述】:

总的来说,我是 TensorFlow 和 ML 的新手。

我正在尝试在 TensorFlow 2.9.1 (Python 3.9.12) 中构建一个简单的线性回归模型,该模型对每日天气数据块进行训练并预测特定特征。我将数据集分为训练集、验证集和测试集。我想绘制从test_inputs 集合预测的值,但linear.predict(test_inputs) 的输出具有test_inputs 的形状,而不是像我预期的test_labels

我正在使用的数据具有以下形状:

<data>.shape = (years, days, features)
train_inputs.shape = (91, 245, 6)
train_labels.shape = (91, 1, 1)
val_inputs.shape = (26, 245, 6)
val_labels.shape = (26, 1, 1)
test_inputs.shape = (13, 245, 6)
test_labels.shape = (13, 1, 1)

我构建和训练模型如下:

linear = tf.keras.Sequential([
    tf.keras.layers.Dense(units=1)
])

early_stopping = tf.keras.callbacks.EarlyStopping(monitor=\'val_loss\',
                                                    patience=2,
                                                    mode=\'min\')
MAX_EPOCHS = 1000
# Build model
linear.compile(loss=tf.losses.MeanSquaredError(),
                optimizer=tf.optimizers.Adam(),
                metrics=[tf.metrics.MeanAbsoluteError()])

# Train model
linear.fit(x=train_inputs, y=train_labels, epochs=MAX_EPOCHS,
                      validation_data=(val_inputs, val_labels),
                      callbacks=[early_stopping],
                      verbose=1)

# Evaluate model
linear.evaluate(x=test_inputs, y=test_labels)

然后我尝试通过以下方式从我的test_inputs 数据集中获取预测值:

predictions = linear(test_inputs)

我希望predictions.shape 给出(13, 1, 1),但它却给出了(13, 245, 1)。任何帮助将不胜感激。

    标签: python tensorflow machine-learning keras


    【解决方案1】:

    912613 似乎是数据集的批量大小。

    尝试扁平化您的输入 -

    linear = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(245, 6)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(1)
    ])
    

    【讨论】:

    • 谢谢你的建议。扁平化输入对我有什么作用?
    • 将您的数据从 2d 转换为 1d
    • 在这种情况下,为什么我需要我的输入数据是一维的?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-09-04
    • 1970-01-01
    • 1970-01-01
    • 2016-07-30
    • 2019-07-04
    • 1970-01-01
    相关资源
    最近更新 更多