【问题标题】:Tensorflow dataset batching for complex data用于复杂数据的 TensorFlow 数据集批处理
【发布时间】:2018-09-03 05:59:30
【问题描述】:

我尝试按照此链接中的示例进行操作:

https://www.tensorflow.org/programmers_guide/datasets

但我完全不知道如何运行会话。我理解第一个参数是要运行的操作,而 feed_dict 是占位符(我的理解是训练或测试数据集的批次),

所以,这是我的代码:

batch_size = 100
handle_mix = tf.placeholder(tf.float64, shape=[])
handle_src0 = tf.placeholder(tf.float64, shape=[])
handle_src1 = tf.placeholder(tf.float64, shape=[])
handle_src2 = tf.placeholder(tf.float64, shape=[])
handle_src3 = tf.placeholder(tf.float64, shape=[])

我从 mp4 音轨和词干创建数据集,读取混合和源量级,并将它们填充以适合批处理

dataset = tf.data.Dataset.from_tensor_slices(
    {"x_mixed":padded_lbl, "y_src0": padded_src[0], "y_src1":      
    padded_src[1],"y_src2": padded_src[1], "y_src3": padded_src[1]})
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)

从我应该做的例子中:

next_element = iterator.get_next()

training_init_op = iterator.make_initializer(dataset)
for _ in range(20):
    # Initialize an iterator over the training dataset.
    sess.run(training_init_op)
    for _ in range(100):
        sess.run(next_element)

但是,我有损失、汇总和优化器操作,需要将数据作为批次提供,下面的另一个示例如下:

l, _, summary = sess.run([loss_fn, optimizer, summary_op], feed_dict=    {handle_mix: batch_mix, handle_src0: batch_src0, handle_src1: batch_src1, handle_src2: batch_src2, handle_src3: batch_src3})

所以我想到了类似的东西:

batch_mix、batch_src0、batch_src1、batch_src2、batch_src3 = data.train.next_batch(batch_size) 或者可能是单独运行以先获取批次,然后按上述方式运行优化,例如:

batch_mix, batch_src0, batch_src1, batch_src2, batch_src3 = sess.run(next_element)
l, _, summary = sess.run([loss_fn, optimizer, summary_op], feed_dict={handle_mix: batch_mix, handle_src0: batch_src0, handle_src1: batch_src1, handle_src2: batch_src2, handle_src3: batch_src3})

最后一次尝试返回了在 tf.data.Dataset.from_tensor_slices 中创建的批次的字符串名称(“x_mixed”、“y_src0”、...等),并且未能在会话中转换为 tf.float64 占位符.

能否请您告诉我如何创建此数据集,首先张量切片的结构可能存在错误,然后如何对它们进行批处理,

非常感谢,

【问题讨论】:

    标签: tensorflow iterator dataset batching


    【解决方案1】:

    问题是您在从张量切片创建数据集时将数据打包到字典中。这将导致iterator.get_next() 也将每个批次作为字典返回。如果我们做类似的事情

    d = {"a": 1, "b": 2}
    k1, k2 = d
    

    我们得到k1 == "a"k2 == "b"(或者由于字典键无序而反过来)。也就是说,您尝试解压缩 sess.run(next_element) 的结果只会为您提供 dict 键,而您对 dict values (张量)感兴趣。这应该可以代替:

    next_element = iterator.get_next()
    x_mixed = next_element["x_mixed"]
    y_src0 = next_element["y_src0"]
    ...
    

    如果您随后基于变量x_mixed 等构建模型,它应该可以正常工作。请注意,使用tf.data API,您不需要占位符! Tensorflow 将看到您的模型输出需要例如x_mixed,它来自 iterator.get_next(),所以只要你尝试 sess.run() 你的损失函数/优化器等,它就会简单地执行这个操作。如果你对占位符更满意,你当然可以继续使用它们,只是记得正确解压字典。这应该是正确的:

    batch_dict = sess.run(next_element)
    l, _, summary = sess.run([loss_fn, optimizer, summary_op], feed_dict={handle_mix: batch_dict["x_mixed"], ... })
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-10-30
      • 2018-08-17
      • 2020-10-29
      • 1970-01-01
      • 1970-01-01
      • 2018-08-24
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多