【发布时间】:2020-10-04 21:37:46
【问题描述】:
我正在将 python 代码从 keras 命名空间转换为 tf.keras。它训练 Resnet50。 新的 Model.fit() 方法找不到我的简单生成器的适配器,validation_data 甚至不再支持生成器。所以我正在尝试使用 tensorflow.data.Dataset.from_generator 方法将其转换为 Dataset。
图像是灰度图像并以原始字节存储 - 一个字节对应一个像素。生成器有这样的行
def __next__( self ):
return self.next()
def __call__( self ):
return self.next()
def next( self ):
#reading files
...
resultLabels = numpy.zeros( ( count, len( classes ) ), "float32" )
resultImages = numpy.zeros( ( count, patchSize, patchSize, 3 ), "float32" )
#filling result with images and labels
...
fileBytes = numpy.reshape( numpy.fromfile( self.ImageLabelsAndPaths[i][1], "uint8" ), (patchSize, patchSize), "F" ).astype( "float32" )
imageWithChannels = numpy.zeros( ( patchSize, patchSize, 3 ), "float32" )
# Because Resnet50 requires RGB images and we have grayscale
imageWithChannels[:,:,0] = fileBytes
imageWithChannels[:,:,1] = fileBytes
imageWithChannels[:,:,2] = fileBytes
resultImages[i - cursor] = imageWithChannels
return ( resultImages, resultLabels )
所以 resultImages 是一个长度为 batch_size=16 的数组,其中包含图像像素数组。 Numpy.shape 是 (16, 256, 256, 3),resultLabels 形状是 (16, 3) - 目前有 3 个类。
接下来我将其转换为数据集
trainGenerator = FileIterator( "train" )
trainDataset = tf.data.Dataset.from_generator( trainGenerator, (tf.float32, tf.float32), (tf.TensorShape([batchSize, patchSize, patchSize, 3]), tf.TensorShape([batchSize, len(classes)]) ) )
validationGenerator = FileIterator( "validate" )
validationDataset = tf.data.Dataset.from_generator( validationGenerator, (tf.float32, tf.float32), (tf.TensorShape([batchSize, patchSize, patchSize, 3]), tf.TensorShape([batchSize, len(classes)]) ) )
但我遇到了错误
TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.float32), but the yielded element was [[[[185. 185. 185.]
[158. 158. 158.]
[145. 145. 145.]
...
Dataset.from_generator 的代码示例有一个数组作为元组中的第二项和类似的 output_types=(tf.int64, tf.int64)。我猜它在那里工作。
尝试将数组添加到类型会导致另一个错误
TypeError: unhashable type: 'list'
我应该改变什么才能让它工作?
【问题讨论】:
标签: python-3.x keras neural-network tensorflow-datasets