【问题标题】:What is x_train.reshape() and What it Does?什么是 x_train.reshape() 以及它的作用?
【发布时间】:2023-03-15 22:51:01
【问题描述】:

使用 MNIST 数据集

import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist

# MNIST dataset parameters
num_classes = 10 # total classes (0-9 digits)
num_features = 784 # data features (img shape: 28*28)

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Convert to float32
x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)

# Flatten images to 1-D vector of 784 features (28*28)
x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])

# Normalize images value from [0, 255] to [0, 1]
x_train, x_test = x_train / 255., x_test / 255.

在这些代码的 第 15 行,即,

x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])。我无法理解这些重塑在我们的数据集中的真正作用..??请解释一下。

【问题讨论】:

  • 你在整形前后打印x_trainx_test了吗?
  • 该行顶部的评论解释了它:# Flatten images to 1-D vector of 784 features (28*28)

标签: python machine-learning keras tensorflow2.0 mnist


【解决方案1】:

如第 14 行所述,它将图像展平为 784 个特征 (28*28) 的一维向量,即,它将维度为 28*28 的二维 NumPy 数组转换为长度为 784 的一维数组

【讨论】:

    最近更新 更多