【问题标题】:Fast numpy row slicing on a matrix矩阵上的快速 numpy 行切片
【发布时间】:2020-05-17 12:55:14
【问题描述】:

我有以下问题:我有一个大小为 (m,200) (m = 3683) 的矩阵 yj,并且我有一个字典,它为每个键返回一个用于 yj 的 numpy 行索引数组(对于每个键,大小数组都会改变,以防万一有人想知道)。

现在,我必须多次访问这个矩阵(大约 100 万次),并且我的代码由于索引而变慢(我已经分析了代码,这一步需要 65% 的时间)。

这是我尝试过的:

  1. 首先,使用索引进行切片:
>> %timeit yj[R_u_idx_train[1]]
10.5 µs ± 79.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

变量R_u_idx_train 是具有行索引的字典。

  1. 我认为布尔索引可能会更快:
>> yj[R_u_idx_train_mask[1]]
10.5 µs ± 159 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

R_u_idx_train_mask 是一个字典,它返回一个大小为 m 的布尔数组,其中 R_u_idx_train 给出的索引设置为 True。

  1. 我也试过np.ix_
>> cols = np.arange(0,200)
>> %timeit ix_ = np.ix_(R_u_idx_train[1], cols); yj[ix_]
42.1 µs ± 353 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
  1. 我也试过np.take
>> %timeit np.take(yj, R_u_idx_train[1], axis=0)
2.35 ms ± 88.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

虽然这看起来很棒,但事实并非如此,因为它提供了一个形状为 (R_u_idx_train[1].shape[0], R_u_idx_train[1].shape[0]) 的数组(应该是 (R_u_idx_train[1].shape[0], 200))。我想我没有正确使用该方法。

  1. 我也试过np.compress
>> %timeit np.compress(R_u_idx_train_mask[1], yj, axis=0)
14.1 µs ± 124 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
  1. 最后我尝试使用布尔矩阵进行索引
>> %timeit yj[R_u_idx_train_mask2[1]]
244 µs ± 786 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

那么,10.5 µs ± 79.7 ns per loop 是我能做的最好的吗?我可以尝试使用cython,但这似乎只是索引的大量工作......

非常感谢。

【问题讨论】:

  • 我也会计时,R_u_idx_train[1]yj[idx]。行索引的大小大概是多少?
  • 它没那么大。它有大约 6000 个条目。但是调用这么小的字典应该是微不足道的,对吧?
  • 您可以将结果映射到新的字典,从而避免像newdict = {k: yj[R_u_idx_train[k]] for k in R_u_idx_train.keys()} 那样每次查找 1 次吗?此外,如果您的字典中的键是连续数字,您可以使用 list 而不是 dict - 这会更快。
  • 哦,newdict 很聪明,没想到。不,我的键不是连续的(或者它们不需要是连续的),所以我认为 dict 是最好的数据结构。
  • 通过一些快速测试,时间大致与索引返回的元素数量有关。 yj[idx,:],即。 len(idx)*200。虽然与在 Python 代码中逐个选择元素相比要快,但它仍然不能是瞬时的。结果是一个带有新数据缓冲区的新数组(不是view)。 yj[idx] 是进行numpy 索引的最直接方式。

标签: python arrays numpy indexing numpy-indexing


【解决方案1】:

V.Ayrat 在 cmets 中给出了一个非常聪明的解决方案。

>> newdict = {k: yj[R_u_idx_train[k]] for k in R_u_idx_train.keys()}
>> %timeit newdict[1]
202 ns ± 6.7 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

不管怎样,知道是否有办法使用numpy 来加速它可能还是很酷的!

【讨论】:

    猜你喜欢
    • 2019-06-19
    • 1970-01-01
    • 1970-01-01
    • 2022-10-15
    • 2021-02-19
    • 1970-01-01
    • 2023-03-03
    • 2018-07-18
    • 2018-02-11
    相关资源
    最近更新 更多