【问题标题】:cast tensorflow 2.0 BatchDataset to numpy array将 tensorflow 2.0 BatchDataset 转换为 numpy 数组
【发布时间】:2019-09-04 15:44:30
【问题描述】:

我有这个代码:

(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()

train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)

print(train_dataset, type(train_dataset), test_dataset, type(test_dataset))

我想将这两个BatchDataset 变量转换为numpy arrays,我可以轻松做到吗?我正在使用TF 2.0,但我刚刚找到了将tf.dataTF 1.0 一起转换的代码

【问题讨论】:

    标签: python tensorflow casting


    【解决方案1】:

    对数据集进行批处理后,最后一批的形状可能与其余批次的形状不同。例如,如果您的数据集中总共有 100 个元素,并且您的批次大小为 6,则最后一批的大小仅为 4。(100 = 6 * 16 + 4)。

    因此,在这种情况下,您将无法直接将数据集转换为 numpy。因此,您必须在批处理方法中使用 drop_remainder 参数为 True。如果大小不正确,它将丢弃最后一批。

    之后,我附上了如何将数据集转换为 Numpy 的代码。

    import tensorflow as tf
    import numpy as np
    
    (train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
    
    TRAIN_BUF=1000
    BATCH_SIZE=64
    
    train_dataset = tf.data.Dataset.from_tensor_slices(train_images).
                              shuffle(TRAIN_BUF).batch(BATCH_SIZE, drop_remainder=True)
    test_dataset = tf.data.Dataset.from_tensor_slices(test_images).
                              shuffle(TRAIN_BUF).batch(BATCH_SIZE, drop_remainder=True)
    
    # print(train_dataset, type(train_dataset), test_dataset, type(test_dataset))
    
    train_np = np.stack(list(train_dataset))
    test_np = np.stack(list(test_dataset))
    print(type(train_np), train_np.shape)
    print(type(test_np), test_np.shape)
    

    输出:

    <class 'numpy.ndarray'> (937, 64, 28, 28)
    <class 'numpy.ndarray'> (156, 64, 28, 28)
    

    【讨论】:

    • 这段代码对我来说似乎是正确的,但我试图在 google colab 上运行它,它卡在了将数据转换为列表的线上
    • 目前,Google Colab 默认仍使用 TensorFlow 1.14 版本。因此,您必须通过运行!pip install tensorflow==2.0.0-rc0 手动安装 TF2.0。之后,您将不会遇到提到的冻结问题。
    猜你喜欢
    • 2022-07-20
    • 2021-01-12
    • 2021-07-02
    • 2020-12-31
    • 2018-09-09
    • 1970-01-01
    • 2018-07-22
    • 2016-03-09
    • 2021-08-08
    相关资源
    最近更新 更多