【发布时间】:2021-02-04 08:57:42
【问题描述】:
我是 TensorFlow 的新手，正在尝试用 tf.data.Dataset 为模型输入数据。我使用的是包含 8 个类别的 Cityscapes 数据集。这是我的代码:
import os
import cv2
import numpy as np
import tensorflow as tf
# Target size: every image and mask is resized to H x W pixels.
H = 256
W = 256

# Lookup table indexed by the raw label id stored in a mask file; the value is
# one of 8 coarse categories (0..7). readMask applies it as ``id2cat[mask]``.
# Run lengths of categories 0..7 are 7, 4, 6, 4, 2, 1, 2, 9 -> 35 label ids.
id2cat = np.repeat(np.arange(8), [7, 4, 6, 4, 2, 1, 2, 9])
def readImage(x):
    """Read an image file, resize it to (H, W) and scale pixels to [0, 1].

    Args:
        x: Path to the image, as ``str`` or UTF-8 encoded ``bytes``
            (``tf.numpy_function`` passes string tensors as bytes).

    Returns:
        ``float32`` array of shape ``(H, W, 3)`` with values in ``[0, 1]``.
        Channel order is whatever ``cv2.imread`` produces (BGR for OpenCV).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    if isinstance(x, bytes):
        x = x.decode("utf-8")
    image = cv2.imread(x, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check cv2.resize fails later with an opaque error.
        raise FileNotFoundError(f"Could not read image: {x}")
    image = cv2.resize(image, (W, H))  # cv2 expects (width, height)
    image = image / 255.0
    return image.astype(np.float32)
def readMask(path):
    """Read a label mask, resize it to (H, W) and map raw ids to categories.

    Args:
        path: Path to a single-channel label image, as ``str`` or UTF-8
            encoded ``bytes`` (``tf.numpy_function`` passes string tensors
            as bytes).

    Returns:
        ``int32`` array of shape ``(H, W)`` holding the category from
        ``id2cat`` for every pixel.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
        IndexError: if the mask contains a label id outside ``id2cat``.
    """
    if isinstance(path, bytes):
        path = path.decode("utf-8")
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread returns None instead of raising on unreadable files.
        raise FileNotFoundError(f"Could not read mask: {path}")
    # Label ids are categorical: nearest-neighbour resizing keeps every pixel
    # a real id. The default bilinear interpolation would blend neighbouring
    # ids at region borders into ids that do not belong there.
    mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)
    mask = id2cat[mask]
    return mask.astype(np.int32)
def preprocess(x, y, num_classes=None):
    """Turn one (image, mask) dataset element into model-ready tensors.

    Works with either file paths (scalar string tensors) or already-loaded
    arrays, such as the lists returned by ``loadCityscape``.

    Args:
        x: Image path, or an already-loaded ``(H, W, 3)`` float image.
        y: Mask path, or an already-loaded ``(H, W)`` category mask.
        num_classes: Depth of the one-hot mask. Defaults to
            ``id2cat.max() + 1`` so it always matches the category table.

    Returns:
        ``(image, mask)``: a float32 ``(H, W, 3)`` tensor and an int32
        ``(H, W, num_classes)`` one-hot tensor.
    """
    if num_classes is None:
        # One channel per category in id2cat. A smaller depth makes
        # tf.one_hot silently emit all-zero vectors for the higher ids.
        num_classes = int(id2cat.max()) + 1

    def _as_path(value):
        # Scalar string tensors arrive as bytes (possibly wrapped in a 0-d
        # array); cv2 needs a str path.
        if isinstance(value, np.ndarray):
            value = value.item()
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def f(x, y):
        # A 0-d input is a path to read; anything else is pre-loaded data.
        if np.ndim(x) == 0:
            image = readImage(_as_path(x))
        else:
            image = np.asarray(x, dtype=np.float32)
        if np.ndim(y) == 0:
            mask = readMask(_as_path(y))
        else:
            mask = np.asarray(y, dtype=np.int32)
        return image, mask

    image, mask = tf.numpy_function(f, [x, y], [tf.float32, tf.int32])
    mask = tf.one_hot(mask, num_classes, dtype=tf.int32)
    # numpy_function loses static shape information; restore it.
    image.set_shape([H, W, 3])
    mask.set_shape([H, W, num_classes])
    return image, mask
def tf_dataset(x, y, batch=8, drop_remainder=False, shuffle_buffer=5000):
    """Build a shuffled, preprocessed, batched, repeating tf.data pipeline.

    By default ``Dataset.batch`` keeps a final batch that may be smaller than
    ``batch``, so the static batch dimension is reported as ``None`` (e.g.
    ``(None, 256, 256, 3)``). Pass ``drop_remainder=True`` to discard that
    partial batch and get a static batch dimension equal to ``batch``.

    The dataset is repeated indefinitely. Bound any loop over it, e.g.
    ``for images, masks in dataset.take(n): ...``, or pass
    ``steps_per_epoch`` when training.

    Args:
        x: Images (paths or pre-loaded arrays) accepted by ``preprocess``.
        y: Masks aligned with ``x``.
        batch: Number of elements per batch.
        drop_remainder: Forwarded to ``Dataset.batch``.
        shuffle_buffer: Buffer size forwarded to ``Dataset.shuffle``.

    Returns:
        A ``tf.data.Dataset`` of ``(image_batch, mask_batch)`` pairs.
    """
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset = dataset.map(preprocess)
    dataset = dataset.batch(batch, drop_remainder=drop_remainder)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(2)
    return dataset
def loadCityscape():
    """Eagerly load all training images and label masks of the Cityscape set.

    Reads ``datasets/Cityscape/train/images`` and ``.../masks`` relative to
    this file. Only mask files whose name contains ``'label'`` are used.

    Returns:
        ``(images, masks)``: lists of arrays produced by ``readImage`` and
        ``readMask``. They are paired by sorted file name, which assumes
        each image and its mask share a common name prefix (TODO confirm).

    Raises:
        ValueError: if the number of images and label masks differ.
    """
    # Separate path components (not a hard-coded "\\") so this also works
    # outside Windows.
    trainPath = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             'datasets', 'Cityscape', 'train')
    imagesPath = os.path.join(trainPath, 'images')
    maskPath = os.path.join(trainPath, 'masks')
    print('Loading images and masks for Cityscape dataset...')
    # os.listdir returns entries in arbitrary order; sort both listings so
    # the i-th image lines up with the i-th mask.
    imageFiles = sorted(os.listdir(imagesPath))
    maskFiles = sorted(name for name in os.listdir(maskPath) if 'label' in name)
    if len(imageFiles) != len(maskFiles):
        raise ValueError(
            f"Found {len(imageFiles)} images but {len(maskFiles)} label masks "
            f"under {trainPath}")
    images = [readImage(os.path.join(imagesPath, name)) for name in imageFiles]
    masks = [readMask(os.path.join(maskPath, name)) for name in maskFiles]
    print('Loaded {} images\n'.format(len(images)))
    return images, masks
# Load the Cityscape arrays eagerly, wrap them in a batched tf.data pipeline,
# and print its element spec. The batch dimension shows as None because
# Dataset.batch keeps a possibly smaller final batch by default; batching
# with drop_remainder=True would make it a static 8.
images, masks = loadCityscape()
dataset = tf_dataset(x=images, y=masks, batch=8)
print(dataset)
# To look at real batches, iterate a bounded prefix (the pipeline repeats
# forever), e.g.: for image_batch, mask_batch in dataset.take(1): ...
最后的 print(dataset) 输出:
<PrefetchDataset shapes: ((None, 256, 256, 3), (None, 256, 256, 3)), types: (tf.float32, tf.int32)>
为什么我得到的形状是 (None, 256, 256, 3) 而不是 (8, 256, 256, 3)？另外，我也不太清楚该如何遍历这个数据集。
非常感谢。
【问题讨论】:
标签: python tensorflow