【问题标题】:Faster alternatives to numpy.argmax/argmin which is slow速度较慢的 numpy.argmax/argmin 的更快替代方案
【发布时间】:2014-11-08 11:26:00
【问题描述】:

我在 Python 中使用了很多 argmin 和 argmax。

很遗憾,这个功能很慢。

我做了一些搜索,我能找到的最好的在这里:

http://lemire.me/blog/archives/2008/12/17/fast-argmax-in-python/

def fastest_argmax(array):
    array = list( array )
    return array.index(max(array))

不幸的是,这个解决方案仍然只有 np.max 的一半,我想我应该能够找到与 np.max 一样快的东西。

x = np.random.randn(10)
%timeit np.argmax( x )
10000 loops, best of 3: 21.8 us per loop

%timeit fastest_argmax( x )    
10000 loops, best of 3: 20.8 us per loop

作为说明,我将其应用于 Pandas DataFrame Groupby

例如

%timeit grp2[ 'ODDS' ].agg( [ fastest_argmax ] )
100 loops, best of 3: 8.8 ms per loop

%timeit grp2[ 'ODDS' ].agg( [ np.argmax ] )
100 loops, best of 3: 11.6 ms per loop

数据如下所示:

grp2[ 'ODDS' ].head()
Out[60]: 
EVENT_ID   SELECTION_ID        
104601100  4367029       682508    3.05
                         682509    3.15
                         682510    3.25
                         682511    3.35
           5319660       682512    2.04
                         682513    2.08
                         682514    2.10
                         682515    2.12
                         682516    2.14
           5510310       682520    4.10
                         682521    4.40
                         682522    4.50
                         682523    4.80
                         682524    5.30
           5559264       682526    5.00
                         682527    5.30
                         682528    5.40
                         682529    5.50
                         682530    5.60
           5585869       682533    1.96
                         682534    1.97
                         682535    1.98
                         682536    2.02
                         682537    2.04
           6064546       682540    3.00
                         682541    2.74
                         682542    2.76
                         682543    2.96
                         682544    3.05
104601200  4916112       682548    2.64
                         682549    2.68
                         682550    2.70
                         682551    2.72
                         682552    2.74
           5315859       682557    2.90
                         682558    2.92
                         682559    3.05
                         682560    3.10
                         682561    3.15
           5356995       682564    2.42
                         682565    2.44
                         682566    2.48
                         682567    2.50
                         682568    2.52
           5465225       682573    1.85
                         682574    1.89
                         682575    1.91
                         682576    1.93
                         682577    1.94
           5773661       682588    5.00
                         682589    4.40
                         682590    4.90
                         682591    5.10
           6013187       682592    5.00
                         682593    4.20
                         682594    4.30
                         682595    4.40
                         682596    4.60
104606300  2489827       683438    4.00
                         683439    3.90
                         683440    3.95
                         683441    4.30
                         683442    4.40
           3602724       683446    2.16
                         683447    2.32
Name: ODDS, Length: 65, dtype: float64

【问题讨论】:

  • 我建议查看 numpy 源代码,了解他们是如何做到的。这是一个链接:github.com/numpy/numpy/blob/… 尝试模仿他们的方法。
  • “效率低下”甚至没有开始描述您所谓的“fastest_argmax”。
  • 在我的笔记本电脑 %timeit np.argmax( x )x.shape == 10 上只花了 1 我们,你的 numpy 版本是什么?
  • 您链接的代码主要是关于正确处理axis 参数。实际的argmin/maxhere 定义的函数执行。请注意,它是模板代码,有关语法说明,请参阅here
  • ivan - 你对更高效的版本有什么建议吗?

标签: python numpy


【解决方案1】:

事实证明,np.argmax 非常快,但 使用原生 numpy 数组。对于国外数据,几乎所有的时间都花在了转化上:

In [194]: print platform.architecture()
('64bit', 'WindowsPE')

In [5]: x = np.random.rand(10000)
In [57]: l=list(x)
In [123]: timeit numpy.argmax(x)
100000 loops, best of 3: 6.55 us per loop
In [122]: timeit numpy.argmax(l)
1000 loops, best of 3: 729 us per loop
In [134]: timeit numpy.array(l)
1000 loops, best of 3: 716 us per loop

我称您的函数“效率低下”,因为它首先将所有内容都转换为列表,然后对其进行 2 次迭代(实际上是 3 次迭代 + 列表构造)。

我打算这样建议只迭代一次的东西:

def imax(seq):
    it=iter(seq)
    im=0
    try: m=it.next()
    except StopIteration: raise ValueError("the sequence is empty")
    for i,e in enumerate(it,start=1):
        if e>m:
            m=e
            im=i
    return im

但是,事实证明您的版本更快,因为它迭代了很多次,但它是在 C 代码而不是 Python 代码中完成的。 C 的速度要快得多 - 即使考虑到大量时间也用于转换这一事实:

In [158]: timeit imax(x)
1000 loops, best of 3: 883 us per loop
In [159]: timeit fastest_argmax(x)
1000 loops, best of 3: 575 us per loop

In [174]: timeit list(x)
1000 loops, best of 3: 316 us per loop
In [175]: timeit max(l)
1000 loops, best of 3: 256 us per loop
In [181]: timeit l.index(0.99991619010758348)  #the greatest number in my case, at index 92
100000 loops, best of 3: 2.69 us per loop

因此,进一步加快此过程的关键知识是了解序列中的数据本身是哪种格式(例如,您是否可以省略转换步骤或使用/编写该格式的其他功能)。

顺便说一句,使用aggregate(max_fn) 代替agg([max_fn]) 可能会加快速度。

【讨论】:

  • 感谢伊万的帮助!
【解决方案2】:

对于那些返回第一个最小值的索引的短 numpy-free sn-p:

def argmin(a):
    return min(range(len(a)), key=lambda x: a[x])
a = [6, 5, 4, 1, 1, 3, 2]
argmin(a)  # returns 3

【讨论】:

  • mhm,只要把min改成max就可以转成argmax
  • 嗯,可以
  • 非常优雅的解决方案。谢谢先生。
【解决方案3】:

你能发布一些代码吗?这是我电脑上的结果:

x = np.random.rand(10000)
%timeit np.max(x)
%timeit np.argmax(x)

输出:

100000 loops, best of 3: 7.43 µs per loop
100000 loops, best of 3: 11.5 µs per loop

【讨论】:

  • 你的终端是什么?普通的控制台字体无法打印 mu 字母 AFAIK。
  • @ivan_pozdeev - 大多数现代控制台模拟器至少有一些(通常是完整的)unicode 支持(甚至是 xterm)。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2020-06-30
  • 2012-07-14
相关资源
最近更新 更多