【发布时间】:2020-09-30 07:30:17
【问题描述】:
我有使用 CIFAR-100 数据集创建的 tensorflow Dataset 对象。我需要访问 Dataset 对象内的火车标签 TensorSpec。由于 TensorSliceDataset 对象不支持index.如何访问每个TensorSpec 并遍历其中的值。
(train_data, train_labels), (test_data, test_labels) = cifar100.load_data(label_mode='fine')
with open('data/cifar100/cifar100_labels.json', 'r') as j:
cifar_labels = json.load(j)
dataset = tf.data.Dataset.from_tensor_slices((train_data,train_labels))
print(train_dataset.element_spec)
# (TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None),
# TensorSpec(shape=(1,), dtype=tf.int64, name=None))
【问题讨论】:
标签: python loops tensorflow tensor