【问题标题】:How to view the image from test generator to see if the prediction is correct or not如何从测试生成器查看图像以查看预测是否正确
【发布时间】:2021-11-05 14:04:38
【问题描述】:

我正在训练一个水果分类模型。截至目前,我的课程是: ['新鲜苹果'、'新鲜香蕉'、'新鲜橙子']

我正在使用 ImageDataGenerator 和 flow_from_directory 的训练、验证和测试生成器。我已经训练了模型,现在想将测试生成器输入模型以查看模型的性能。现在我在测试生成器中只有 2 张图像。我有以下代码进行预测:

    predictions = tuned_model.predict(test_generator)
    score = tf.nn.softmax(predictions[0])

    print(
        'This image most likely belongs to {} with a {:.2f} percent 
    confidence.'.format(
            class_names[np.argmax(score)], 100 * np.max(score)
        )
    )

结果如下:

    This image most likely belongs to Fresh Apples with a 46.19 percent confidence.

是的,准确率很低,我只训练了 10 个 epoch,哈哈。但是,有没有办法可以看到正在测试的图像?或者知道这个预测是否正确的方法?

编辑:

包括生成器的代码...

generator = ImageDataGenerator(
    rotation_range=45,
    rescale=1./255,
    horizontal_flip=True,
    vertical_flip=True,
    validation_split=.2
)

train_generator = generator.flow_from_directory(
    train_path,
    target_size=(im_height, im_width),
    batch_size = batch_size,
    subset='training'
)

validation_generator = generator.flow_from_directory(
    train_path,
    target_size=(im_height, im_width),
    batch_size=batch_size,
    subset='validation'
)

test_generator = generator.flow_from_directory(
    test_path,
    target_size= (im_height, im_width),
    batch_size= batch_size,
)

就我的班级标签而言,到目前为止,我只是对它们进行了硬编码

class_names = ['Fresh Apples', 'Fresh Bananas', 'Fresh Bananas']

我知道我可能应该导入 os 并根据文件结构创建标签,但除非我绝对需要,否则我稍后会这样做。

【问题讨论】:

  • 请包含更多代码来定义您的测试集标签以及您是否已将它们包含在测试生成器中。这样会更容易提供帮助,谢谢!
  • @TCArlen 我更新了我的 train、val 和 test 生成器的代码

标签: tensorflow keras conv-neural-network image-classification


【解决方案1】:

我假设您在创建测试生成器时在 flow_from_directory 中设置了 shuffle=False。然后使用

files=test_generator.filenames
class_dict=test_generator.class_indices # a dictionary of the form class name: class index
rev_dict={}
for key, value in class_dict.items()
    rev_dict[value]=key   # dictionary of the form class index: class name

files 是文件名列表,按文件呈现以供预测的顺序排列。 然后做

predictions = tuned_model.predict(test_generator)

然后遍历预测

for i, p in enumerate(predictions)
    index=np.argmax(p)
    klass=rev_dict[index]    
    prob=p[index]
    print('for file ', files[i], ' predicted class is ', klass,' with probability ',prob)

当然你也可以显示图片

【讨论】:

  • 谢谢,这正是我需要的
  • 欢迎,希望对你有用
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2021-09-08
  • 1970-01-01
  • 2012-02-16
  • 2017-12-20
  • 2023-03-05
  • 2022-01-09
相关资源
最近更新 更多