【问题标题】:How to replace a value within a tensor by indices?如何用索引替换张量中的值?
【发布时间】:2017-03-19 05:56:42
【问题描述】:

以下代码通过索引向张量内的特定位置添加一些内容(感谢@mrry 的回答here)。

indices = [[1, 1]]  # A list of coordinates to update.
values = [1.0]  # A list of values corresponding to the respective
            # coordinate in indices.
shape = [3, 3]  # The shape of the corresponding dense tensor, same as `c`.
delta = tf.SparseTensor(indices, values, shape)

例如,给定这个 -

c = tf.constant([[0.0, 0.0, 0.0],
             [0.0, 0.0, 0.0],
             [0.0, 0.0, 0.0]])

它会在 [1, 1] 处加 1,结果是

[[0.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 0.0]])

问题 - 是否可以在特定位置替换值而不是在该位置添加?如果在 tensorflow 中不可以,在任何其他类似的库中是否可以?

例如,

鉴于此-

[[4.0, 43.1.0, 45.0],
[2.0, 22.0, 6664.0],
[-4543.0, 0.0, 43.0]])

有没有办法将 [1, 1] 处的 22 替换为(比如说)45,结果如下?

[[4.0, 43.1.0, 45.0],
[2.0, 45.0, 6664.0],
[-4543.0, 0.0, 43.0]])

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    这很笨重,但它确实替换了张量中的值。它基于您提到的this answer

    # inputs
    inputs = tf.placeholder(shape = [None, None], dtype = tf.float32)  # tensor with values to replace
    indices = tf.placeholder(shape = [None, 2], dtype = tf.int64)  # coordinates to be updated
    values = tf.placeholder(shape = [None], dtype = tf.float32)  # values corresponding to respective coordinates in "indices"
    
    # set elements in "indices" to 0's
    maskValues = tf.tile([0.0], [tf.shape(indices)[0]])  # one 0 for each element in "indices"
    mask = tf.SparseTensor(indices, maskValues, tf.shape(inputs, out_type = tf.int64))
    maskedInput = tf.multiply(inputs, tf.sparse_tensor_to_dense(mask, default_value = 1.0))  # set values in coordinates in "indices" to 0's, leave everything else intact
    
    # replace elements in "indices" with "values"
    delta = tf.SparseTensor(indices, values, tf.shape(inputs, out_type = tf.int64))
    outputs = tf.add(maskedInput, tf.sparse_tensor_to_dense(delta))  # add "values" to elements in "indices" (which are 0's so far)
    

    它的作用:

    1. 将需要替换的位置的输入元素设置为 0。
    2. 将所需的值添加到这些 0 中(直接来自 here)。

    通过运行检查:

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ins = np.array([[4.0, 43.0, 45.0], [2.0, 22.0, 6664.0], [-4543.0, 0.0, 43.0]])
        ind = [[1, 1]]
        vals = [45]
        outs = sess.run(outputs, feed_dict = { inputs: ins, indices: ind, values: vals })
        print(outs)
    

    输出:

    [[ 4.000e+00  4.300e+01  4.500e+01]
     [ 2.000e+00  4.500e+01  6.664e+03]
     [-4.543e+03  0.000e+00  4.300e+01]]
    

    与许多otherwise great answers 不同,这个功能超越了tf.Variable()s。

    【讨论】:

      【解决方案2】:

      根据自己的值或索引更新张量的一个简单选项是使用tf.wheretf.tensor_scatter_nd_update

      import tensorflow as tf
      
      x = tf.constant([[4.0, 43.0, 45.0],
                       [2.0, 22.0, 6664.0],
                       [-4543.0, 0.0, 43.0]])
      value = 45.0
      indices = [1, 1]
      
      by_indices = tf.tensor_scatter_nd_update(x, [indices], [value])
      tf.print('Using indices\n', by_indices, '\n')
      
      by_value = tf.where(tf.equal(x, 22.0), value, x)
      tf.print('Using value\n', by_value)
      
      Using indices
       [[4 43 45]
       [2 45 6664]
       [-4543 0 43]] 
      
      Using value
       [[4 43 45]
       [2 45 6664]
       [-4543 0 43]]
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2019-09-17
        • 1970-01-01
        • 2021-12-30
        • 1970-01-01
        • 2017-06-21
        • 1970-01-01
        • 2020-10-16
        • 2021-11-03
        相关资源
        最近更新 更多