【发布时间】:2014-10-06 18:52:50
【问题描述】:
这是一个间接索引问题。
可以通过列表推导来解决。
问题是是否或如何在 numpy 中解决它,
什么时候
data.shape 是 (T,N)
和
c.shape 是 (T,K)
并且c的每个元素都是一个介于0和N-1之间的int,即,
c 的每个元素都旨在引用来自 data 的列号。
目标是在哪里获取out
out.shape = (T,K)
对于0..(T-1) 中的每个i
行out[i] = [ data[i, c[i,0]] , ... , data[i, c[i,K-1]] ]
具体例子:
data = np.array([\
[ 0, 1, 2],\
[ 3, 4, 5],\
[ 6, 7, 8],\
[ 9, 10, 11],\
[12, 13, 14]])
c = np.array([
[0, 2],\
[1, 2],\
[0, 0],\
[1, 1],\
[2, 2]])
out should be out = [[0, 2], [4, 5], [6, 6], [10, 10], [14, 14]]
out 的第一行是 [0,2] 因为选择的列是 c 的第 0 行给出的,它们是 0 和 2,而第 0 和 2 列的 data[0] 是 0 和 2。
out的第二行是[4,5],因为选择的列是c的第1行给定的,分别是1和2,而1和2列的data[1]是4和5。
Numpy 花式索引似乎并没有以明显的方式解决这个问题,因为使用 c (例如 data[c]、np.take(data,c,axis=1) )索引数据总是会产生一个 3 维数组。
列表推导可以解决它:
out = [ [data[rowidx,i1],data[rowidx,i2]] for (rowidx, (i1,i2)) in enumerate(c) ]
如果 K 为 2,我想这还可以。如果 K 是可变的,这不是很好。
必须为每个值 K 重写列表推导,因为它会按 c 的每一行展开从 data 中挑选的列。它也违反了 DRY。
有完全基于numpy的解决方案吗?
【问题讨论】: