【问题标题】:Keras MNIST target vector automatically converted to one-hot?Keras MNIST 目标向量自动转换为 one-hot?
【发布时间】:2020-10-10 22:30:03
【问题描述】:

当我从 Keras 加载 mnist 数据集时,我得到 4 个变量 -

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train 的形状是(60000, 28, 28),这很有意义,因为它包含 60,000 张 28x28 的图片

y_train 的形状只是(60000,),这表明它是一个包含数字目标标签 (0-9) 的一维向量。

为了运行数字分类,神经网络通常会输出一个 one-hot 编码向量,该向量将具有 10 个维度。我想我需要使用 to_categorical 将 y 目标从数值转换为分类,以便神经网络的形状输出与训练样本匹配,大概是 (60000, 10)

但在我在网上找到的一些示例中,to_categorical 从未用于重塑训练向量。 y_train.shape 仍然是 (60000,) 而神经网络的输出层是

 model.add(Dense(10, activation="softmax"))

输出一个 10-D one-hot 向量。

然后他们只是在y_train 上训练模型而没有问题

model.fit(x_train, y_train, epochs=2, validation_data=(x_test, y_test))

这怎么可能?形状为(60000, 10) 的神经网络输出不会与(60000,) 不兼容吗?还是 Keras 会自动将分类输出转换为数字?

编辑: 更清楚地说,我知道如何对其进行一次性编码,但我的问题是他们为什么不这样做。在示例中,网络在没有对目标类进行 one-hot 编码的情况下工作,而网络的输出显然是 one-hot 编码的。

编辑: Roshin 是对的。这只是使用 sparse_crossentropy 损失的效果,而不是分类损失。

【问题讨论】:

    标签: python tensorflow keras neural-network


    【解决方案1】:

    将损失函数改为

    loss = 'sparse_categorical_crossentropy'
    

    这将起作用,您不必更改输入数据形状

    【讨论】:

    • 你是对的。但为什么?分类和稀疏之间有什么区别?我知道之前已经回答过了,但为什么会影响数据形状?
    • 我想这个问题会解答你的疑惑:stackoverflow.com/questions/44674847/…
    【解决方案2】:

    您可以通过执行以下代码行自己将其转换为 one-hot:

    (x_train, l_train), (x_test, l_test) = mnist.load_data()
    y_train = np.zeros((l_train.shape[0], l_train.max()+1), dtype=np.float32)
    y_train[np.arange(l_train.shape[0]), l_train] = 1
    y_test = np.zeros((l_test.shape[0], l_test.max()+1), dtype=np.float32)
    y_test[np.arange(l_test.shape[0]), l_test] = 1
    

    【讨论】:

    • 我知道如何对其进行一次热编码,但我的问题是他们为什么不这样做。在示例中,网络在没有对目标类进行 one-hot 编码的情况下工作,而网络的输出显然是 one-hot 编码
    • 可能是他们使用 sparse_categorical_crossentropy 而不是 categorical_crossentropy 训练了模型。如果您的目标标签是要与 softmax one-hot 输出进行比较的整数,则使用 sparse_categorical_crossentropy
    猜你喜欢
    • 2017-09-19
    • 1970-01-01
    • 2019-04-21
    • 1970-01-01
    • 2017-09-22
    • 1970-01-01
    • 2017-07-27
    • 2021-12-05
    • 2021-08-05
    相关资源
    最近更新 更多