对于纯 numpy 解决方案,您可以执行以下操作:
-
使用np.unique分别获取x和y中的唯一值和对应的索引:
# sorted unique values in x and y and the indices corresponding to their first
# occurrences, such that u_x == x[u_idx_x]
u_x, u_idx_x = np.unique(x, return_index=True)
u_y, u_idx_y = np.unique(y, return_index=True)
-
使用np.intersect1d 查找唯一值的交集:
# we can assume_unique, which can be faster for large arrays
i_xy = np.intersect1d(u_x, u_y, assume_unique=True)
-
最后,使用np.in1d 仅选择与x 或y 中的唯一值对应的索引,这些索引也恰好在x 和y 的交集处:
# it is also safe to assume_unique here
i_idx_x = u_idx_x[np.in1d(u_x, i_xy, assume_unique=True)]
i_idx_y = u_idx_y[np.in1d(u_y, i_xy, assume_unique=True)]
将所有这些整合到一个函数中:
def intersect_indices(x, y):
u_x, u_idx_x = np.unique(x, return_index=True)
u_y, u_idx_y = np.unique(y, return_index=True)
i_xy = np.intersect1d(u_x, u_y, assume_unique=True)
i_idx_x = u_idx_x[np.in1d(u_x, i_xy, assume_unique=True)]
i_idx_y = u_idx_y[np.in1d(u_y, i_xy, assume_unique=True)]
return i_idx_x, i_idx_y
例如:
x = np.array([4, 1, 10, 5, 8, 13, 11])
y = np.array([20, 5, 4, 9, 11, 7, 25])
i_idx_x, i_idx_y = intersect_indices(x, y)
print(i_idx_x, i_idx_y)
# (array([0, 3, 6]), array([2, 1, 4]))
速度测试:
In [1]: k = 1000000
In [2]: %%timeit x, y = np.random.randint(k, size=(2, k))
intersect_indices(x, y)
....:
1 loops, best of 3: 597 ms per loop
更新:
我最初错过了这样一个事实,即在您的情况下,x 和 y 都只包含唯一值。考虑到这一点,使用间接排序可能会做得更好:
def intersect_indices_unique(x, y):
u_idx_x = np.argsort(x)
u_idx_y = np.argsort(y)
i_xy = np.intersect1d(x, y, assume_unique=True)
i_idx_x = u_idx_x[x[u_idx_x].searchsorted(i_xy)]
i_idx_y = u_idx_y[y[u_idx_y].searchsorted(i_xy)]
return i_idx_x, i_idx_y
这是一个更实际的测试用例,其中x 和y 都包含唯一(但部分重叠)的值:
In [1]: n, k = 10000000, 1000000
In [2]: %%timeit x, y = (np.random.choice(n, size=k, replace=False) for _ in range(2))
intersect_indices(x, y)
....:
1 loops, best of 3: 593 ms per loop
In [3]: %%timeit x, y = (np.random.choice(n, size=k, replace=False) for _ in range(2))
intersect_indices_unique(x, y)
....:
1 loops, best of 3: 453 ms per loop
@Divakar 的解决方案在性能方面非常相似:
In [4]: %%timeit x, y = (np.random.choice(n, size=k, replace=False) for _ in range(2))
searchsorted_based(x, y)
....:
1 loops, best of 3: 472 ms per loop