【问题标题】:Pandas: increase speed of rolling window (apply a custom function)熊猫:提高滚动窗口的速度(应用自定义功能)
【发布时间】:2019-09-11 20:02:17
【问题描述】:

我正在使用此代码使用滚动窗口在我的数据框上应用函数 (funcX)。主要问题是这个数据框 (data) 的大小非常大,我正在寻找一种更快的方法来完成这项任务。

import numpy as np

def funcX(x):
    x = np.sort(x)
    xd = np.delete(x, 25)
    med = np.median(xd)
    return (np.abs(x - med)).mean() + med

med_out = data.var1.rolling(window = 51, center = True).apply(funcX, raw = True)

使用这个函数的唯一原因是计算出的中位数是去掉中间值后的中位数。所以在滚动窗口末尾添加.median()是不同的。

【问题讨论】:

    标签: python performance dataframe median rolling-computation


    【解决方案1】:

    要有效,窗口算法必须链接两个重叠窗口的结果。

    这里,用:med0 中位数,medx \ med0 的中位数,xl 之前的元素,medxg 之后的元素在med 之后的元素中,funcX(x) 可以是被视为:

    <|x-med|> + med = [sum(xg) - sum(xl) - |med0-med|] / windowsize + med  
    

    因此,一个想法是维护一个表示已排序的当前窗口sum(xg)sum(xl) 的缓冲区。使用 Numba 即时编译,这里会产生非常好的性能。

    首先是缓冲区管理:

    init 对第一个窗口进行排序并计算左(xls) 和右(xgs) 和。

    import numpy as np
    import numba
    windowsize = 51 #odd, >1
    halfsize = windowsize//2
    
    @numba.njit
    def init(firstwindow):
        buffer = np.sort(firstwindow)
        xls = buffer[:halfsize].sum()
        xgs = buffer[-halfsize:].sum()   
        return buffer,xls,xgs
    

    shift 是线性部分。它更新缓冲区,保持它的排序。 np.searchsorted 计算 O(log(windowsize)) 中的插入和删除位置。这是技术问题,因为 xin&lt;xoutxout&lt;xin 不是对称情况。

    @numba.njit
    def shift(buffer,xin,xout):
        i_in = np.searchsorted(buffer,xin) 
        i_out = np.searchsorted(buffer,xout)
        if xin <= xout :
            buffer[i_in+1:i_out+1] = buffer[i_in:i_out] 
            buffer[i_in] = xin                        
        else:
            buffer[i_out:i_in-1] = buffer[i_out+1:i_in]                      
            buffer[i_in-1] = xin
        return i_in, i_out
    

    update 更新缓冲区和左右部分的总和。这是技术问题,因为 xin&lt;xoutxout&lt;xin 不是对称情况。

    @numba.njit
    def update(buffer,xls,xgs,xin,xout):
        xl,x0,xg = buffer[halfsize-1:halfsize+2]
        i_in,i_out = shift(buffer,xin,xout)
    
        if i_out < halfsize:
            xls -= xout
            if i_in <= halfsize:
                xls += xin
            else:    
                xls += x0
        elif i_in < halfsize:
            xls += xin - xl
    
        if i_out > halfsize:
            xgs -= xout
            if i_in > halfsize:
                xgs += xin
            else:    
                xgs += x0
        elif i_in > halfsize+1:
            xgs += xin - xg
    
        return buffer, xls, xgs
    

    func 相当于缓冲区上的原始funcXO(1).

    @numba.njit       
    def func(buffer,xls,xgs):
        med0 = buffer[halfsize]
        med  = (buffer[halfsize-1] + buffer[halfsize+1])/2
        if med0 > med:
            return (xgs-xls+med0-med) / windowsize + med
        else:               
            return (xgs-xls+med-med0) / windowsize + med    
    

    med 是全局函数。 O(data.size * windowsize).

    @numba.njit
    def med(data):
        res = np.full_like(data, np.nan)
        state = init(data[:windowsize])
        res[halfsize] = func(*state)
        for i in range(windowsize, data.size):
            xin,xout = data[i], data[i - windowsize]
            state = update(*state, xin, xout)
            res[i-halfsize] = func(*state)
        return res 
    

    性能:

    import pandas
    data=pandas.DataFrame(np.random.rand(10**5))
    
    %time res1=data[0].rolling(window = windowsize, center = True).apply(funcX, raw = True)
    Wall time: 10.8 s
    
    res2=med(data[0].values)
    
    np.allclose((res1-res2)[halfsize:-halfsize],0)
    Out[112]: True
    
    %timeit res2=med(data[0].values)
    40.4 ms ± 462 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

    它快了 ~ 250 倍,窗口大小 = 51。一小时变成了 15 秒。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-09-23
      • 2017-03-30
      • 2019-07-02
      • 1970-01-01
      • 1970-01-01
      • 2015-09-30
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多