【问题标题】:BatchDataSet: get img array and labelsBatchDataSet:获取 img 数组和标签
【发布时间】:2022-01-25 08:45:26
【问题描述】:

这是我之前创建的适合模型的批处理数据集:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_path,
    label_mode = 'categorical', #it is used for multiclass classification. It is one hot encoded labels for each class
    validation_split = 0.2,     #percentage of dataset to be considered for validation
    subset = "training",        #this subset is used for training
    seed = 1337,                # seed is set so that same results are reproduced
    image_size = img_size,      # shape of input images
    batch_size = batch_size,    # This should match with model batch size
)




valid_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_path,
    label_mode ='categorical',
    validation_split = 0.2,
    subset = "validation",      #this subset is used for validation
    seed = 1337,
    image_size = img_size,
    batch_size = batch_size,
)

如果我运行一个 for 循环,我可以访问 img 数组和标签:

for images, labels in train_ds:
    print(labels)

但如果我尝试像这样访问它们:

尝试 1)

images, labels = train_ds

我得到以下值错误:ValueError: too many values to unpack (expected 2)

尝试 2:

如果我尝试这样解包:

images = train_ds[:,0] # get the 0th column of all rows 
labels = train_ds[:,1] # get the 1st column of all rows 

我收到以下错误:TypeError: 'BatchDataset' object is not subscriptable

有没有办法让我在不通过 for 循环的情况下提取标签和图像?

【问题讨论】:

  • train_ds的数据类型是什么?
  • 它的一个:tensorflow.python.data.ops.dataset_ops.BatchDataset

标签: python tensorflow keras deep-learning tensorflow-datasets


【解决方案1】:

只需取消批处理数据集并将数据转换为列表:

import tensorflow as tf
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" 
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir) 
batch_size = 32 
train_ds = tf.keras.utils.image_dataset_from_directory( 
     data_dir, validation_split=0.2, subset="training", 
     seed=123, batch_size=batch_size) 

train_ds = train_ds.unbatch()
images = list(train_ds.map(lambda x, y: x))
labels = list(train_ds.map(lambda x, y: y))

print(len(labels))
print(len(images))
Found 3670 files belonging to 5 classes.
Using 2936 files for training. 
2936 
2936

【讨论】:

    【解决方案2】:

    对于您的具体情况,train_ds 将是一个张量对象,其中每个元素都是一个元组:(image,label)

    可能会尝试类似:

    # train_ds = [(image,label) …]
    images = train_ds[:,0] # get the 0th column of all rows 
    labels = train_ds[:,1] # get the 1st column of all rows 
    
    

    【讨论】:

    • 感谢您的回复!我曾尝试过类似的方法,但出现以下错误:TypeError: 'BatchDataset' object is not subscriptable , the dataset is a 'tensorflow.python.data.ops.dataset_ops.BatchDataset'
    • 我的错!我只是假设您的数据集的数据类型将是一个数组。嗯……如果你想从标签中分离出图像,我想最好的方法是使用 for 循环(就像你之前做的那样)。不过我可能错了!
    猜你喜欢
    • 2022-01-24
    • 2021-01-29
    • 2012-05-19
    • 1970-01-01
    • 1970-01-01
    • 2021-07-02
    • 2010-10-04
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多