【发布时间】:2020-03-07 19:45:05
【问题描述】:
我有一个用于图像分类的标准 CNN,使用以下生成器来获取数据集:
generator = validation_image_generator.flow_from_directory(batch_size=BATCH_SIZE,
directory=val_dir,
shuffle=False,
target_size=(100,100),
class_mode='categorical')
我可以很容易地得到预测的标签:
predictions = model.predict(dataset)
现在我想获取所有预测的(原始)真实标签和图像,以与预测相同的顺序进行比较。我确信这些信息很容易存储在某个地方,但我一直找不到。
【问题讨论】:
-
真正的标签是什么意思?这个函数给你在训练集中使用的标签(通常是数字)!所以你知道每个标签背后的实际含义。
-
我的意思是数据集中该图像的原始标签,而不是模型的预测。
-
你能发布更多的代码而不是上面的一行吗?例如。您定义生成器的部分。这样我们就能更好地帮助您。
-
@Tinu 你是对的!我已经添加了生成器代码。
标签: python tensorflow keras conv-neural-network