【发布时间】:2019-04-24 21:48:18
【问题描述】:
我在操作 Numpy 数组,代码如下:
z[np.arange(n), y]
其中 z 是二维数组,y 是一维数组。此外,z.shape[0] == y.shape[0] == n.
我怎样才能做与 TensorFlow 张量等效的事情?
【问题讨论】:
标签: arrays python-3.x numpy tensorflow
我在操作 Numpy 数组,代码如下:
z[np.arange(n), y]
其中 z 是二维数组,y 是一维数组。此外,z.shape[0] == y.shape[0] == n.
我怎样才能做与 TensorFlow 张量等效的事情?
【问题讨论】:
标签: arrays python-3.x numpy tensorflow
您可以使用tf.gather_nd 来获取您想要的索引。
import numpy as np
import tensorflow as tf
# Numpy implementation
n = 3
z = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
y = np.array([0, 1, 1])
assert z.shape[0] == y.shape[0] == n
np_out = z[np.arange(n), y]
# TF implementation
tf.reset_default_graph()
range_t = tf.range(n) # Equiv to np.arange
x_y = tf.stack([range_t, y], axis=1) # Get (x,y) as a tuple
pick_by_index_from_z = tf.gather_nd(z, x_y) # Pick the right values from z
with tf.Session() as sess:
tf_out = sess.run(pick_by_index_from_z)
# The np and tf values should be the same
assert (np_out == tf_out).all()
print('z:')
print(z)
print('\nnp_out:')
print(np_out)
print('\ntf_out:')
print(tf_out)
这给出了输出:
z:
[[1 2 3]
[4 5 6]
[7 8 9]]
np_out:
[1 5 8]
tf_out:
[1 5 8]
【讨论】: