【问题标题】:Update slice of tensor in tensorflow更新张量流中的张量切片
【发布时间】:2019-01-09 20:28:17
【问题描述】:

我想更新一个 3 维的张量切片。关注How to do slice assignment in Tensorflow 我会做类似的事情

import tensorflow as tf

with tf.Session() as sess:
    init_val = tf.Variable(tf.zeros((2, 3, 3)))
    indices = tf.constant([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]])
    update = tf.scatter_nd_add(init_val, indices, tf.ones(4))

    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(update))

这可行,但由于我的实际问题更复杂,我想通过定义切片的开头和大小以某种方式自动生成索引集,例如如果您使用tf.slice(...)。你有什么想法?提前致谢!

我使用的是 TensorFlow 1.12,它是当前最新版本。

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    tf.strided_slice 支持传递 var 参数来指示切片所引用的变量,因此当您传递它时,它将返回一个可分配对象(我不确定他们为什么不这样做,具体取决于输入的类型,但无论如何)。你可以这样做:

    import tensorflow as tf
    import numpy as np
    
    var = tf.Variable(np.ones((3, 4), dtype=np.float32))
    s = tf.strided_slice(var, [0, 2], [2, 3], var=var, name='var_slice')
    s2 = s.assign([[2], [3]])
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        print(sess.run(s2))
    

    输出:

    [[1. 1. 2. 1.]
     [1. 1. 3. 1.]
     [1. 1. 1. 1.]]
    

    请注意,在tf.strided_slice 中,您提供了开始和结束索引(不包括结束),与在tf.slice 中,您提供开始和大小不同。此外,就目前的代码而言,您必须为切片或分配操作提供名称值(我认为这应该是一个错误并且发生,因为这部分 API 几乎完全在内部使用)。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2017-12-01
      • 2018-03-15
      • 2019-11-05
      • 1970-01-01
      • 2019-05-30
      • 2018-01-31
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多