【发布时间】:2017-09-28 13:01:11
【问题描述】:
这个问题和this one一样,但是是针对Tensorflow的。
假设我有“行”的二维张量,并且想要从每一行中选择第 i 个元素并组成这些元素的结果列,在选择器张量中包含 i-s,如下所示
import tensorflow as tf
import numpy as np
rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64)
# gives
# array([[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11],
# [12, 13, 14],
# [15, 16, 17],
# [18, 19, 20],
# [21, 22, 23],
# [24, 25, 26],
# [27, 28, 29]])
selector = tf.get_variable("selector", [10,1], dtype=tf.int8, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]]))
result_of_selection = ...
# should be
# array([[ 0],
# [ 4],
# [ 6],
# [11],
# [13],
# [15],
# [18],
# [23],
# [26],
# [28]])
我该怎么做?
更新
我是这样写的(感谢@Psidom)
import tensorflow as tf
import numpy as np
rows = tf.constant(np.arange(10*3).reshape(10,3), dtype=tf.float64)
# selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([0, 1, 0, 2, 1, 0, 0, 2, 2, 1], dtype=tf.int32))
# selector = tf.expand_dims(selector, axis=1)
selector = tf.get_variable("selector", dtype=tf.int32, initializer=tf.constant([[0], [1], [0], [2], [1], [0], [0], [2], [2], [1]], dtype=tf.int32))
ordinals = tf.reshape(tf.range(rows.shape[0]), (-1,1))
#idx = tf.concat([selector, ordinals], axis=1)
idx = tf.stack([selector, ordinals], axis=-1)
result = tf.gather_nd(rows, idx)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
rows_value, result_value = sess.run([rows, result])
print("rows_value: " + str(rows_value))
print("selector_value: " + str(result_value))
它给了
rows_value: [[ 0. 1. 2.]
[ 3. 4. 5.]
[ 6. 7. 8.]
[ 9. 10. 11.]
[ 12. 13. 14.]
[ 15. 16. 17.]
[ 18. 19. 20.]
[ 21. 22. 23.]
[ 24. 25. 26.]
[ 27. 28. 29.]]
selector_value: [[ 0.]
[ 4.]
[ 2.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]]
即不正确。
更新 2
固定线路
idx = tf.stack([ordinals, selector], axis=-1)
顺序不正确。
【问题讨论】:
标签: python tensorflow