【发布时间】:2021-07-19 04:12:22
【问题描述】:
我正在通过关注this TensorFlow tutorial 并从 Gdrive 加载我自己的数据集来进行图像分类。 现在我想绘制混淆矩阵。首先,我预测了验证数据集的标签:
val_preds = model.predict(val_ds)
但我不确定如何获取原始标签以将预测与它们进行比较。我尝试了不同的方法,但准确度非常低,所以我知道标签不应该是这样。
val_ds_labels = np.concatenate([y for x, y in val_ds], axis=0)
这给了我 0.067 的准确度,而下面给了我大约 0.70 的准确度。
epochs = 10
history=model.fit(train_ds, epochs=epochs, validation_data=val_ds)
这是我创建验证和训练数据集的方式:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
"images",
validation_split=0.2,
subset="training",
seed=123,
image_size=image_size,
batch_size=batch_size,
label_mode='int'
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
"images",
validation_split=0.2,
subset="validation",
seed=123,
image_size=image_size,
batch_size=batch_size,
label_mode='int'
)
train_ds = train_ds.prefetch(buffer_size=32)
val_ds = val_ds.prefetch(buffer_size=32)
然后创建模型并编译它:
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseTopKCategoricalAccuracy(k=1)],
)
合身
epochs = 10
history=model.fit(train_ds, epochs=epochs, validation_data=val_ds)
我有 22 个标签。
val_preds = model.predict(val_ds)
【问题讨论】:
标签: tensorflow keras deep-learning multiclass-classification image-classification