【问题标题】:Index operation in TensorFlowTensorFlow 中的索引操作
【发布时间】:2016-03-06 23:32:56
【问题描述】:

我在对一些数据做批量标注的时候,有一个变量用来记录所有的计算结果:

p_all = tf.Variable(tf.zeros([batch_num, batch_size]), name = "probability");

在计算中,我有一个循环来处理每个批次:

for i in range(batch_num):
    feed = {x: testDS.test.next_batch(batch_size)}
    sess.run(p_each_batch, feed_dict=feed)

如何将p_each_bach 的值复制到p_all 中?

为了更清楚,我想要类似的东西:

... ...
p_all[batch_index,:] = p_each_batch
for i in range(batch_num):
    feed = {x: testDS.test.next_batch(batch_size), batch_index: i}
    sess.run(p_all, feed_dict=feed)

我怎样才能使这些代码真正起作用?

【问题讨论】:

    标签: indexing tensorflow


    【解决方案1】:

    由于p_alltf.Variable,您可以使用tf.scatter_update() 操作来更新每批中的各个行:

    # Equivalent to `p_all[batch_index, :] = p_each_batch`
    update_op = tf.scatter_update(p_all,
                                  tf.expand_dims(batch_index, 0),
                                  tf.expand_dims(p_each_batch, 0)) 
    
    for i in range(batch_num):
        feed = {x: testDS.test.next_batch(batch_size), batch_index: i}
        sess.run(update_op, feed_dict=feed)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2014-09-25
      • 2011-03-24
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-01-26
      • 1970-01-01
      相关资源
      最近更新 更多