【问题标题】:How can I use Tensorflows scatter_nd with slices?如何将 Tensorflows scatter_nd 与切片一起使用?
【发布时间】:2019-04-02 11:23:35
【问题描述】:

我正在尝试仅优化变量的一部分。我找到了this 看似有用的答案。

但是我的变量是一个图像,我只想更改它的一部分,所以我试图将代码扩展到更多维度。这似乎工作正常:

import tensorflow as tf
import tensorflow.contrib.opt as opt

X = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

# the next two lines need to change because
# manually specifying the values is not feasible
indexes = tf.constant([[0, 0], [1, 0]])
updates = [X[0, 0], X[1, 0]]

part_X = tf.scatter_nd(indexes, updates, [2, 2])
X_2 = part_X + tf.stop_gradient(-part_X + X)
Y = tf.constant([[2.5, -3.5], [5.5, -7.5]])
loss = tf.reduce_sum(tf.squared_difference(X_2, Y))
opt = opt.ScipyOptimizerInterface(loss, [X])

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    opt.minimize(sess)
    print("X: {}".format(X.eval()))

但是,由于我的图像尺寸和我想选择的区域要大得多,因此手动指定所有索引是不可行的。我想知道如何使用切片或范围分配来做到这一点。

【问题讨论】:

    标签: python tensorflow range slice


    【解决方案1】:

    你可以这样做:

    import tensorflow as tf
    
    # Input with size (50, 100)
    X = tf.Variable([[0] * 100] * 50)
    # Selected slice
    row_start = 10
    row_end = 30
    col_start = 20
    col_end = 50
    # Make indices from meshgrid
    indexes = tf.meshgrid(tf.range(row_start, row_end),
                          tf.range(col_start, col_end), indexing='ij')
    indexes = tf.stack(indexes, axis=-1)
    # Take slice
    updates = X[row_start:row_end, col_start:col_end]
    # Build tensor with "filtered" gradient
    part_X = tf.scatter_nd(indexes, updates, tf.shape(X))
    X_2 = part_X + tf.stop_gradient(-part_X + X)
    # Continue as before...
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2016-02-24
      • 2016-08-11
      • 1970-01-01
      • 2012-05-24
      相关资源
      最近更新 更多