【问题标题】:What are the actual class labels while using SparseCategoricalCrossEntropy loss for multiclass classification in keras?在 keras 中使用 SparseCategoricalCrossEntropy 损失进行多类分类时,实际的类标签是什么?
【发布时间】:2021-10-20 07:02:13
【问题描述】:

我正在尝试使用:[SparseCategoricalCrossEntropy][https://www.tensorflow.org/api_docs/python/tf/keras/losses/SparseCategoricalCrossentropy] 进行多类分类

这将给我最后一个维度作为类的数量 (N_CLASSES)。但我想从预测中检索实际的类标签。

基本上,如果我有 5 个类 (N_CLASSES=5),那么我有 5 列,每列包含该类的概率。但我不知道哪一列属于哪个实际标签。如何检索实际的类标签?

例如,如果我的实际类标签为 [1.03, 2.07, -2.09, -974, 366],那么从形状 (None, 5) 的输出中,我如何知道哪一列代表哪个类?

注意:由于内存问题,我无法使用 CategoricalCrossEntropy 并传入 one-hot 编码的实际目标表示。

任何帮助将不胜感激

【问题讨论】:

    标签: python tensorflow keras tensorflow2.0


    【解决方案1】:

    其实很简单。假设您的模型输出predictions = [1.03, 2.07, -2.09, -974, 366]。这 5 个数字代表您的模型对您的输入数据对应于 5 个不同类别中的每一个的信心。如果您随后将np.argmax 应用于您的预测,则返回predictions 中最大值的索引:

    np.argmax(predictions)

    您将获得索引 4。假设您的数据集中的每个标签都是 0 到 4 之间的整数,并且由于您使用的是SparseCategoricalCrossEntropy,您可以说您的模型最有信心您的输入数据属于类4(无论第 4 类是什么)。我希望你能明白。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-06-26
      • 1970-01-01
      • 2021-07-17
      • 2017-10-25
      • 2019-08-16
      • 2020-10-26
      • 2018-08-04
      • 2017-06-18
      相关资源
      最近更新 更多