【问题标题】:Find position of maximum per unique bin (binargmax)查找每个唯一 bin 的最大值位置 (binargmax)
【发布时间】:2018-08-24 14:48:18
【问题描述】:

设置

假设我有

bins = np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2])
vals = np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])
k = 3

我需要bins 中唯一 bin 的最大值位置。

# Bin == 0
#  ↓ ↓           ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
#  ↑ ↑           ↑
#  ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 8 and happens at position 0

(vals * (bins == 0)).argmax()

0

# Bin == 1
#      ↓ ↓         ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
#      ↑ ↑         ↑
#        ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 4 and happens at position 3

(vals * (bins == 1)).argmax()

3

# Bin == 2
#          ↓ ↓ ↓     ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
#          ↑ ↑ ↑     ↑
#                    ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 9 and happens at position 9

(vals * (bins == 2)).argmax()

9

这些函数很老套,甚至不能泛化为负值。

问题

如何使用 Numpy 以最有效的方式获取所有这些值?

我尝试过的。

def binargmax(bins, vals, k):
  out = -np.ones(k, np.int64)
  trk = np.empty(k, vals.dtype)
  trk.fill(np.nanmin(vals) - 1)

  for i in range(len(bins)):
    v = vals[i]
    b = bins[i]
    if v > trk[b]:
      trk[b] = v
      out[b] = i

  return out

binargmax(bins, vals, k)

array([0, 3, 9])

LINK TO TESTING AND VALIDATION

【问题讨论】:

  • 因此,k 始终为否。独特的垃圾箱?
  • 是的,应该和bins.max() + 1一样
  • 是否保证每个 bin 的值是唯一的?你想要所有的最大值吗?
  • 不保证,我要第一名。喜欢np.array([1, 2, 2]).argmax() 返回1 @user3483203
  • 当然...(-:对不起,我错过了。完成!

标签: python numpy


【解决方案1】:

numpy_indexed 库:

我知道这在技术上不是 numpy,但 numpy_indexed 库有一个矢量化的 group_by 函数,非常适合这个,只是想分享作为我经常使用的替代方案:

>>> import numpy_indexed as npi
>>> npi.group_by(bins).argmax(vals)
(array([0, 1, 2]), array([0, 3, 9], dtype=int64))

使用简单的pandas groupbyidxmax

df = pd.DataFrame({'bins': bins, 'vals': vals})
df.groupby('bins').vals.idxmax()

使用sparse.csr_matrix

这个选项在非常大的输入上非常快。

sparse.csr_matrix(
    (vals, bins, np.arange(vals.shape[0]+1)), (vals.shape[0], k)
).argmax(0)

# matrix([[0, 3, 9]])

性能

函数

def chris(bins, vals, k):
    return npi.group_by(bins).argmax(vals)

def chris2(df):
    return df.groupby('bins').vals.idxmax()

def chris3(bins, vals, k):
    sparse.csr_matrix((vals, bins, np.arange(vals.shape[0] + 1)), (vals.shape[0], k)).argmax(0)

def divakar(bins, vals, k):
    mx = vals.max()+1

    sidx = bins.argsort()
    sb = bins[sidx]
    sm = np.r_[sb[:-1] != sb[1:],True]

    argmax_out = np.argsort(bins*mx + vals)[sm]
    max_out = vals[argmax_out]
    return max_out, argmax_out

def divakar2(bins, vals, k):
    last_idx = np.bincount(bins).cumsum()-1
    scaled_vals = bins*(vals.max()+1) + vals
    argmax_out = np.argsort(scaled_vals)[last_idx]
    max_out = vals[argmax_out]
    return max_out, argmax_out


def user545424(bins, vals, k):
    return np.argmax(vals*(bins == np.arange(bins.max()+1)[:,np.newaxis]),axis=-1)

def user2699(bins, vals, k):
    res = []
    for v in np.unique(bins):
        idx = (bins==v)
        r = np.where(idx)[0][np.argmax(vals[idx])]
        res.append(r)
    return np.array(res)

def sacul(bins, vals, k):
    return np.lexsort((vals, bins))[np.append(np.diff(np.sort(bins)), 1).astype(bool)]

@njit
def piRSquared(bins, vals, k):
    out = -np.ones(k, np.int64)
    trk = np.empty(k, vals.dtype)
    trk.fill(np.nanmin(vals))

    for i in range(len(bins)):
        v = vals[i]
        b = bins[i]
        if v > trk[b]:
            trk[b] = v
            out[b] = i

    return out

设置

import numpy_indexed as npi
import numpy as np
import pandas as pd
from timeit import timeit
import matplotlib.pyplot as plt
from numba import njit
from scipy import sparse

res = pd.DataFrame(
       index=['chris', 'chris2', 'chris3', 'divakar', 'divakar2', 'user545424', 'user2699', 'sacul', 'piRSquared'],
       columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000],
       dtype=float
)

k = 5

for f in res.index:
    for c in res.columns:
        bins = np.random.randint(0, k, c)
        k = 5
        vals = np.random.rand(c)
        df = pd.DataFrame({'bins': bins, 'vals': vals})
        stmt = '{}(df)'.format(f) if f in {'chris2'} else '{}(bins, vals, k)'.format(f)
        setp = 'from __main__ import bins, vals, k, df, {}'.format(f)
        res.at[f, c] = timeit(stmt, setp, number=50)

ax = res.div(res.min()).T.plot(loglog=True)
ax.set_xlabel("N");
ax.set_ylabel("time (relative)");

plt.show()

结果

k 更大的结果(这是广播受到重创的地方):

res = pd.DataFrame(
       index=['chris', 'chris2', 'chris3', 'divakar', 'divakar2', 'user545424', 'user2699', 'sacul', 'piRSquared'],
       columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000],
       dtype=float
)

k = 500

for f in res.index:
    for c in res.columns:
        bins = np.random.randint(0, k, c)
        vals = np.random.rand(c)
        df = pd.DataFrame({'bins': bins, 'vals': vals})
        stmt = '{}(df)'.format(f) if f in {'chris2'} else '{}(bins, vals, k)'.format(f)
        setp = 'from __main__ import bins, vals, df, k, {}'.format(f)
        res.at[f, c] = timeit(stmt, setp, number=50)

ax = res.div(res.min()).T.plot(loglog=True)
ax.set_xlabel("N");
ax.set_ylabel("time (relative)");

plt.show()

从图中可以明显看出,当组的数量较少时,广播是一个绝妙的技巧,但是在较高的k 值下,广播的时间复杂度/内存增长过快,无法使其具有高性能。

【讨论】:

  • 介意添加我的时间吗?
  • 不错的基准!作为 numpy_indexed 的作者,让我注意到该库已优化为“numpythonic”和通用。也就是说,您的 bin 不必是从 0 开始的整数;但可以是任何类型和任何维度的 ndarray 事实上。这确实会在这里和那里增加一点开销,但如果性能是您的主要目标,那么对于这类问题,确实没有与 numba 争论。不过,拥有一个带有简单 API 的参考实现来测试您的低级代码仍然很好!
  • 非常好的使用稀疏。您为我的工具箱提供了两个好主意。
  • 您可能想在此处使用 CSR 与 CSC 稀疏矩阵进行测试。由于正在执行的操作类型,一个可能会更快。我认为论点几乎相同。我会在使用计算机时发布 CSC 解决方案。
【解决方案2】:

这里有一种方法是偏移每个组数据,以便我们可以一次性对整个数据使用argsort -

def binargmax_scale_sort(bins, vals):
    w = np.bincount(bins)
    valid_mask = w!=0
    last_idx = w[valid_mask].cumsum()-1
    scaled_vals = bins*(vals.max()+1) + vals
    #unique_bins = np.flatnonzero(valid_mask) # if needed
    return len(bins) -1 -np.argsort(scaled_vals[::-1], kind='mergesort')[last_idx]

【讨论】:

  • @piRSquared 建议更好的解决方案 - 对于 bin 不覆盖0-bins.max() 范围的情况,输出唯一的 bin 不是更好吗?
  • @piRSquared 是的,我正在使用bins, vals = gen_arrays(5000, 10000) 进行测试,而我修改后的解决方案仅涵盖独特的解决方案,而不是整个范围,因此与binargmax 不匹配。
  • 我期待你的 scaled_vals,因为我看到你以前使用过它。使用 cumsum 推导出 last_idx 以预期对 argsort 的结果进行切片!?杰出的!虽然我讨厌这种人,但我不能否认它的聪明才智。
  • @piRSquared 发现我可以使用 bincount 来获取每组的最后一个索引是我必须承认的 eureka 小时刻之一。有趣的问答肯定是这个。获得 argmax(第一个)需要更多的脑力劳动 :)
【解决方案3】:

好的,这是我的线性时间条目,仅使用索引和np.(max|min)inum.at。它假设 bin 从 0 上升到 max(bins)。

def via_at(bins, vals):
    max_vals = np.full(bins.max()+1, -np.inf)
    np.maximum.at(max_vals, bins, vals)
    expanded = max_vals[bins]
    max_idx = np.full_like(max_vals, np.inf)
    np.minimum.at(max_idx, bins, np.where(vals == expanded, np.arange(len(bins)), np.inf))
    return max_vals, max_idx

【讨论】:

    【解决方案4】:

    这个怎么样:

    >>> import numpy as np
    >>> bins = np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2])
    >>> vals = np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])
    >>> k = 3
    >>> np.argmax(vals*(bins == np.arange(k)[:,np.newaxis]),axis=-1)
    array([0, 3, 9])
    

    【讨论】:

    • 这很聪明(-:时间复杂度和内存需求会随着k大而爆炸(我认为)。
    • @piRSquared,我已经为此设置了一些基准。使用 30 个左右的 bin 效果很好,性能下降 1000 次。只有 3 个垃圾箱,它是迄今为止最快的答案。
    • 我也在做同样的事情。这应该是线性的,长度为vals。当我应用 Numba 的 njit 时,我的初始方法是最快的。我会展示它。我想要一个 O(n) Numpy 方法。这确实很接近。
    【解决方案5】:

    如果您追求可读性,这可能不是最好的解决方案,但我认为它有效

    def binargsort(bins,vals):
        s = np.lexsort((vals,bins))
        s2 = np.sort(bins)
        msk = np.roll(s2,-1) != s2
        # or use this for msk, but not noticeably better for performance:
        # msk = np.append(np.diff(np.sort(bins)),1).astype(bool)
        return s[msk]
    
    array([0, 3, 9])
    

    解释

    lexsort按照bins的排序顺序对vals的索引进行排序,然后按照vals的排序:

    >>> np.lexsort((vals,bins))
    array([7, 1, 0, 8, 2, 3, 4, 5, 6, 9])
    

    因此,您可以通过 bins 的排序位置从一个索引到下一个索引进行屏蔽:

    >>> np.sort(bins)
    array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2])
    
    # Find where sorted bins end, use that as your mask on the `lexsort`
    >>> np.append(np.diff(np.sort(bins)),1)
    array([0, 0, 1, 0, 0, 1, 0, 0, 0, 1])
    
    >>> np.lexsort((vals,bins))[np.append(np.diff(np.sort(bins)),1).astype(bool)]
    array([0, 3, 9])
    

    【讨论】:

    • 查看我的链接的验证部分。这是返回最大值的最后一个位置。
    • hmm...我编辑的解决方案(使用s2 = np.sort(bins); msk = np.roll(s2,-1) != s2)通过了前两个验证,但没有通过第三个...不确定发生了什么,试图弄清楚。
    【解决方案6】:

    这是一个有趣的小问题。我的方法是根据bins 中的值获取vals 的索引。使用where 获取索引为True 的点,结合vals 中这些点上的argmax 得到结果值。

    def binargmaxA(bins, vals):
        res = []
        for v in unique(bins):
            idx = (bins==v)
            r = where(idx)[0][argmax(vals[idx])]
            res.append(r)
        return array(res)
    

    可以通过使用range(k) 来删除对unique 的调用以获取可能的bin 值。这加快了速度,但随着 k 大小的增加,性能仍然很差。

    def binargmaxA2(bins, vals, k):
        res = []
        for v in range(k):
            idx = (bins==v)
            r = where(idx)[0][argmax(vals[idx])]
            res.append(r)
        return array(res)
    

    最后一次尝试,比较每个值会大大减慢速度。此版本计算排序后的值数组,而不是对每个唯一值进行比较。好吧,它实际上计算了排序后的索引,并且只在需要时获取排序后的值,因为这样可以避免一次将 val 加载到内存中。性能仍会随着 bin 数量的增加而变化,但比以前慢了很多。

    def binargmaxB(bins, vals):
        idx = argsort(bins)   # Find sorted indices
        split = r_[0, where(diff(bins[idx]))[0]+1, len(bins)]  # Compute where values start in sorted array
        newmax = [argmax(vals[idx[i1:i2]]) for i1, i2 in zip(split, split[1:])]  # Find max for each value in sorted array
        return idx[newmax +split[:-1]] # Convert to indices in unsorted array
    

    基准

    以下是其他答案的一些基准。

    3000 个元素

    使用更大的数据集 (bins = randint(0, 30, 3000); vals = randn(3000); k=30;)

    • 171us Divakar 的 binargmax_scale_sort2
    • 209us 这个答案,B 版
    • 281us binargmax_scale_sort by Divakar
    • 329us用户545424的广播版本
    • 399us 这个答案,A 版
    • 416us 由 sacul 回答,使用 lexsort
    • 899us piRsquared 提供的参考代码

    30000 个元素

    还有一个更大的数据集 (bins = randint(0, 30, 30000); vals = randn(30000); k=30)。令人惊讶的是,这并没有改变解决方案之间的相对性能。

    • 1.27ms 这个答案,B 版
    • 2.01ms divakar 的 binargmax_scale_sort2
    • 2.38ms user545424 的广播版本
    • 2.68ms 这个答案,A 版
    • 5.71ms 由 sacul 回答,使用 lexsort
    • 9.12ms piRSquared 提供的参考代码

    编辑我没有随着可能的 bin 值数量的增加而更改 k,因为我已经修复了基准更加均匀。

    1000 个 bin 值

    增加唯一 bin 值的数量也可能对性能产生影响。 Divakar 和 sacul 的解决方案大多不受影响,而其他解决方案则具有相当大的影响。 bins = randint(0, 1000, 30000); vals = randn(30000); k = 1000

    • 1.99ms Divakar 的 binargmax_scale_sort2
    • 3.48ms 这个答案,B 版
    • 6.15ms 由 sacul 回答,使用 lexsort
    • 10.6ms piRsquared 提供的参考代码
    • 27.2ms 这个答案,A 版
    • 129ms 由 user545424 播放的版本

    编辑包括问题中参考代码的基准,它具有惊人的竞争力,尤其是在具有更多 bin 的情况下。

    【讨论】:

      【解决方案7】:

      我知道你说过要使用 Numpy,但如果 Pandas 是可以接受的:

      import numpy as np; import pandas as pd;
      (pd.DataFrame(
          {'bins':np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2]),
           'values':np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])}) 
      .groupby('bins')
      .idxmax())
      
            values
      bins        
      0          0
      1          3
      2          9
      

      【讨论】:

        猜你喜欢
        • 2019-11-10
        • 2011-04-16
        • 1970-01-01
        • 1970-01-01
        • 2021-01-19
        • 1970-01-01
        • 1970-01-01
        • 2017-06-11
        • 2021-02-03
        相关资源
        最近更新 更多