【发布时间】:2018-12-05 05:21:33
【问题描述】:
我希望创建一个管道来向神经网络提供非标准文件(例如扩展名为 *.xxx)。 目前我的代码结构如下:
1) 我定义了一个查找训练文件的路径列表
2) 我定义了一个包含这些路径的 tf.data.Dataset 对象的实例
3) 我将一个 python 函数映射到数据集,该函数采用每个路径并返回关联的 numpy 数组(从 pc 上的文件夹加载);这个数组是一个维度为 [256, 256, 192] 的矩阵。
4) 我定义了一个可初始化的迭代器,然后在网络训练期间使用它。
我怀疑我提供给网络的批次的大小。我想向网络提供大小为 64 的批次。我该怎么办? 例如,如果我将函数 train_data.batch(b_size) 与 b_size = 1 一起使用,结果是当迭代时,迭代器给出一个形状为 [256, 256, 192] 的元素;如果我想用这个数组的 64 个切片来喂神经网络呢?
这是我的代码的摘录:
with tf.name_scope('data'):
train_filenames = tf.constant(list_of_files_train)
train_data = tf.data.Dataset.from_tensor_slices(train_filenames)
train_data = train_data.map(lambda filename: tf.py_func(
self._parse_xxx_data, [filename], [tf.float32]))
train_data.shuffle(buffer_size=len(list_of_files_train))
train_data.batch(b_size)
iterator = tf.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)
input_data = iterator.get_next()
train_init = iterator.make_initializer(train_data)
[...]
with tf.Session() as sess:
sess.run(train_init)
_ = sess.run([self.train_op])
提前致谢
---------
我在下面的 cmets 中发布了我的问题的解决方案。我仍然很乐意收到有关可能改进的任何意见或建议。谢谢;)
【问题讨论】:
标签: python tensorflow iterator dataset