【问题标题】:How to Iterate over TensorSliceDataset object in Tensorflow如何在 Tensorflow 中迭代 TensorSliceDataset 对象
【发布时间】:2020-09-30 07:30:17
【问题描述】:

我有使用 CIFAR-100 数据集创建的 tensorflow Dataset 对象。我需要访问 Dataset 对象内的火车标签 TensorSpec。由于 TensorSliceDataset 对象不支持index.如何访问每个TensorSpec 并遍历其中的值。

(train_data, train_labels), (test_data, test_labels) = cifar100.load_data(label_mode='fine')
with open('data/cifar100/cifar100_labels.json', 'r') as j:
    cifar_labels = json.load(j)

dataset = tf.data.Dataset.from_tensor_slices((train_data,train_labels))
print(train_dataset.element_spec)

# (TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 
# TensorSpec(shape=(1,), dtype=tf.int64, name=None))

【问题讨论】:

    标签: python loops tensorflow tensor


    【解决方案1】:

    你可以把标签变成一个数组:

    import tensorflow as tf
    
    (train_data, train_labels), (test_data, test_labels) = tf.keras.datasets.mnist.load_data()
    
    dataset = tf.data.Dataset.from_tensor_slices((train_data,train_labels))
    
    next(dataset.batch(60_000).as_numpy_iterator())[1]
    
    array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
    

    这是你要找的吗?

    【讨论】:

    猜你喜欢
    • 2020-05-01
    • 2019-02-10
    • 2020-02-26
    • 2018-10-25
    • 1970-01-01
    • 2021-04-09
    • 1970-01-01
    • 2017-03-20
    • 2018-08-02
    相关资源
    最近更新 更多