【发布时间】:2019-10-05 20:33:58
【问题描述】:
我想对张量进行切片,如下面的 numpy 切片。我该怎么做?
# numpy array
a = np.reshape(np.arange(60), (3,2,2,5))
idx = np.array([0, 1, 0])
N = np.shape(a)[0]
mask = a[np.arange(N),:,:,idx]
# I have tried several solutions, but only the following success.
# tensors
import tensorflow as tf
import numpy as np
a = tf.cast(tf.constant(np.reshape(np.arange(60), (3,2,2,5))), tf.int32)
idx2 = tf.constant([0, 1, 0])
fn = lambda i: a[i][:,:,idx2[i]]
idx = tf.range(tf.shape(a)[0])
masks = tf.map_fn(fn, idx)
with tf.Session() as sess:
print(sess.run(a))
print(sess.run(tf.shape(masks)))
print(sess.run(masks))
有没有更简单的方法来实现这一点?
我可以使用函数tf.gather 或tf.gather_nd 来实现吗?
非常感谢!
【问题讨论】:
标签: python tensorflow slice