【发布时间】:2022-01-24 02:44:48
【问题描述】:
我有一个Train 和Validation 批处理数据集:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
train_path,
label_mode = 'categorical', #it is used for multiclass classification. It is one hot encoded labels for each class
validation_split = 0.2, #percentage of dataset to be considered for validation
subset = "training", #this subset is used for training
seed = 1337, # seed is set so that same results are reproduced
image_size = img_size, # shape of input images
batch_size = batch_size, # This should match with model batch size
)
valid_ds = tf.keras.preprocessing.image_dataset_from_directory(
train_path,
label_mode ='categorical',
validation_split = 0.2,
subset = "validation", #this subset is used for validation
seed = 1337,
image_size = img_size,
batch_size = batch_size,
)
我试图显示 9 张图像以显示它们的外观,我设法做到了,但我似乎无法绘制它们各自的标签。
代码如下:
class_names = train_ds.class_names
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
显示这个:
如果我尝试通过添加来获取标签:plt.title(class_names[labels[i]])
我收到以下错误:TypeError: only integer scalar arrays can be converted to a scalar index
我已经尝试过其他帖子的解决方案,例如以下plt.title(class_names[labels[i][0]]),但没有任何成功。
当我打印标签[i]时,我得到标签的一个热编码......也许这就是为什么?
结果代码:
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.title(class_names[np.argmax(labels[i], axis=None, out=None)])
plt.axis("off")
【问题讨论】:
标签: python tensorflow keras deep-learning