【问题标题】:How to speed up nested loops for groupby multiindex如何加速 groupby 多索引的嵌套循环
【发布时间】:2024-01-15 21:09:01
【问题描述】:

我有两个 Multiindex 数据框,即 panel1 和 panel2:两者都有相同的 0 级索引 - 日期,但不同的 1 级索引;请参阅下面的示例代码:

# panel1:
idx1 = pd.MultiIndex.from_product([['2017-05-02', '2017-05-03', '2017-05-04'],['id1', 'id2', 'id3']],names=['Dates', 'id'])
panel1=pd.DataFrame(np.random.randn(9,2), index=idx1,columns=['ytm','mat'])
# panel2:
idx2 = pd.MultiIndex.from_product([['2017-05-02', '2017-05-03', '2017-05-04'],['0.5', '1.5', '2.5']],names=['Dates', 'yr'])
panel2=pd.DataFrame(np.random.randn(9), index=idx2,columns=['curve'])

我想按日期(0 级索引)循环遍历两个面板。因此,对于每一天(例如“2017-05-02”),我在 yr 列(panel2)中搜索每个 id/row(panel1)的 ma​​t ,如果有匹配,我想获取对应的 curve 值(panel2)并将其添加为 panel1 中的新列(名为 CDB)。

我目前的代码如下:

group1=panel1.groupby(level=0)
group2=panel2.groupby(level=0)

lst=[]
for ytm in group1:              # loop over each day
    for yr in group2:           # loop over each day
        df_ytm=ytm[1]           # get df of id, yt & mat
        df_ytm=df_ytm.assign(CDB=np.nan)      # add a col of nan, later will be replaced by matched curve values
        df_curve=yr[1].reset_index()          # need get rid of index to match yr with t_mat
        df_curve.yr=df_curve.yr.astype(float) 
        for i in range(df_ytm.shape[0]):      # loop over each row
            if (df_ytm.iloc[i,1]==df_curve.yr).any()==True:      # search if each 'mat' value in 'yr' column
                df_ytm.iloc[i,2]=df_curve[df_curve.yr.isin([df_ytm.t_mat[i]])].curve.values   # if matched, set 'CDB' as curve value
    lst.append(df_ytm)      # need get modified 'df_ytm' (with matched 'CDB')  

代码在我尝试使用小样本时有效,但我有一个巨大的面板 1(大小为 800 天乘以 10000 个 ID)和大面板 2。所以代码已经运行了24小时以上。

我想知道如何重写代码(使用可能的矢量化)以加快速度?

任何 cmets 将不胜感激!

【问题讨论】:

    标签: python pandas performance vectorization multi-index


    【解决方案1】:

    如果我理解正确,您需要从Dates 索引和mat 列构造新的MultiIndex,并为此索引获取curve 的值。

    import pandas as pd
    import numpy as np
    
    np.random.seed(12)
    idx1 = pd.MultiIndex.from_product(
        [["2017-05-02", "2017-05-03", "2017-05-04"], ["id1", "id2", "id3"]],
        names=["Dates", "id"],
    )
    panel1 = pd.DataFrame(
        np.random.randint(3, size=(9, 2)), index=idx1, columns=["ytm", "mat"]
    )
    idx2 = pd.MultiIndex.from_product(
        [["2017-05-02", "2017-05-03", "2017-05-04"], ["0", "1", "2"]], names=["Dates", "yr"]
    )
    panel2 = pd.DataFrame(np.random.randint(3, size=9), index=idx2, columns=["curve"])
    print(panel1)
    #                 ytm  mat
    # Dates      id
    # 2017-05-02 id1    2    1
    #            id2    1    2
    #            id3    0    0
    # 2017-05-03 id1    2    1
    #            id2    0    1
    #            id3    1    1
    # 2017-05-04 id1    2    2
    #            id2    2    0
    #            id3    1    0
    print(panel2)
    #                curve
    # Dates      yr
    # 2017-05-02 0       0
    #            1       1
    #            2       2
    # 2017-05-03 0       1
    #            1       2
    #            2       0
    # 2017-05-04 0       1
    #            1       2
    #            2       0
    panel1["CDM"] = panel2.loc[
        pd.MultiIndex.from_arrays(
            [panel1.index.get_level_values(0), panel1.mat.astype(str).rename("yr")]
        )
    ].to_numpy()
    print(panel1)
    #                 ytm  mat  CDM
    # Dates      id
    # 2017-05-02 id1    2    1    1
    #            id2    1    2    2
    #            id3    0    0    0
    # 2017-05-03 id1    2    1    2
    #            id2    0    1    2
    #            id3    1    1    2
    # 2017-05-04 id1    2    2    0
    #            id2    2    0    1
    #            id3    1    0    1
    

    编辑

    matyr 比较为浮点数并使用.reindex 而不是.loc

    import pandas as pd
    import numpy as np
    
    np.random.seed(12)
    idx1 = pd.MultiIndex.from_product(
        [["2017-05-02", "2017-05-03", "2017-05-04"], ["id1", "id2", "id3"]],
        names=["Dates", "id"],
    )
    panel1 = pd.DataFrame(
        np.random.randint(3, size=(9, 2)), index=idx1, columns=["ytm", "mat"]
    )
    panel1.iloc[0, 1] = np.nan
    idx2 = pd.MultiIndex.from_product(
        [["2017-05-02", "2017-05-03", "2017-05-04"], ["0", "1", "2"]], names=["Dates", "yr"]
    )
    panel2 = pd.DataFrame(np.random.randint(3, size=9), index=idx2, columns=["curve"])
    panel2 = panel2.rename(float, level=1)
    print(panel1)
    #                 ytm  mat
    # Dates      id
    # 2017-05-02 id1    2  NaN
    #            id2    1  2.0
    #            id3    0  0.0
    # 2017-05-03 id1    2  1.0
    #            id2    0  1.0
    #            id3    1  1.0
    # 2017-05-04 id1    2  2.0
    #            id2    2  0.0
    #            id3    1  0.0
    print(panel2)
    #                 curve
    # Dates      yr
    # 2017-05-02 0.0      0
    #            1.0      1
    #            2.0      2
    # 2017-05-03 0.0      1
    #            1.0      2
    #            2.0      0
    # 2017-05-04 0.0      1
    #            1.0      2
    #            2.0      0
    panel1["CDM"] = panel2.reindex(
        pd.MultiIndex.from_arrays(
            [panel1.index.get_level_values(0), panel1.mat.rename("yr")]
        )
    ).to_numpy()
    print(panel1)
    #                 ytm  mat  CDM
    # Dates      id
    # 2017-05-02 id1    2  NaN  NaN
    #            id2    1  2.0  2.0
    #            id3    0  0.0  0.0
    # 2017-05-03 id1    2  1.0  2.0
    #            id2    0  1.0  2.0
    #            id3    1  1.0  2.0
    # 2017-05-04 id1    2  2.0  0.0
    #            id2    2  0.0  1.0
    #            id3    1  0.0  1.0
    

    【讨论】:

    • 非常感谢您的解决方案!它似乎有效,但代码panel1.mat.astype(str) 出现了一个问题。由于panel1.mat 是带有 2 个小数位的浮点数据(例如 3.60),它故意设置为与 panel2.yr(带有 2 个小数点的浮点数)匹配,因此,panel1.mat.astype(str) 将 3.60 转换为无法匹配的 3.6在panel2.yr 中格式为 (3.60)。我想知道是否有更好的方法来处理这个问题?或者需要格式化panel2.yr。 @V.Ayrat
    • 只是跟进,还有一个未来警告返回 "main:2: FutureWarning: Passing list-likes to .loc or [] with any missing label will raise KeyError 在未来,你可以使用 .reindex() 作为替代。"。我想知道这会不会是一个问题?
    • 我认为使用浮点格式更好。你可以做panel2 = panel2.rename(float, level=1)。但这提出了浮点数比较的问题。如果为yrmat 分配了相同的值,则完全匹配将起作用,但如果有一些中间计算,这可能会稍微改变这些值并且它们将不匹配......如果有缺失值,那就更好了使用.reindex(),因为它建议禁止警告。
    • 我想知道在这种情况下如何正确使用.reindexmat 列确实有 nan,但 yr 没有丢失数据。再次感谢!
    • 我想panel1.mat.astype(str) 会在mat 格式为浮点数(例如3.60)但yr 格式为字符串(例如3.6)时引发不匹配...?简而言之,我尝试将yr 设置为浮点数;它导致所有级别的 Multiindex 不匹配,因此,使原始建议不起作用...相反,首先将 yr 设置为 float 并再次将其设置为 str,匹配过程有效。另外,我想知道如何在这里正确使用.reindex? @V.Ayrat
    【解决方案2】:

    为了生成我的代码的任何非空且可重复的结果, 我稍微改变了两个面板的创建方式:

    np.random.seed(0)
    idx1 = pd.MultiIndex.from_product([['2017-05-02', '2017-05-03', '2017-05-04'],
        ['id1', 'id2', 'id3']], names=['Dates', 'id'])
    panel1 = pd.DataFrame({'ytm': np.random.randn(9),
        'mat': [0.5, 0.82, 1.06, -0.27, 1.5, 0.59, 0.62, 1.89, 2.5]}, index=idx1)
    idx2 = pd.MultiIndex.from_product([['2017-05-02', '2017-05-03', '2017-05-04'],
        [0.5, 1.5, 2.5]], names=['Dates', 'yr'])
    panel2 = pd.DataFrame(np.random.randn(9), index=idx2, columns=['curve'])
    

    变化包括:

    • np.random.seed - 获得可重复的结果。
    • 只有 panel1ytm 列被创建为随机数。为了 为了在 mat 中有一些匹配的值,我将预定义的值放在那里, 为每个日期提供一个与 yr 匹配的匹配项。
    • idx2 的级别 1 是 float 类型。您的示例包括字符串, 这显然不等于 mat 值。

    我还假设对于 panel1 中的每个组,查找匹配项应该 在 panel2 中以 相同日期 的行中执行(不在 所有日期的组)。

    要生成结果(CDB 列),请执行以下操作:

    1. 为当前组定义一个生成CDB列的函数 行(每个日期):

       def getCDB(grp):
           cdb = panel2.xs(grp.index[0][0], level=0).reindex(grp.mat).curve
           return pd.Series(cdb.values, index=grp.index)
      
    2. 然后应用它并将结果保存在新列中:

       panel1['CDB'] = panel1.groupby(level=0).apply(getCDB)\
           .reset_index(level=0, drop=True)
      

    对于我的输入数据,结果是:

                         ytm   mat       CDB
    Dates      id                           
    2017-05-02 id1  1.764052  0.50  0.410599
               id2  0.400157  0.82       NaN
               id3  0.978738  1.06       NaN
    2017-05-03 id1  2.240893 -0.27       NaN
               id2  1.867558  1.50  0.121675
               id3 -0.977278  0.59       NaN
    2017-05-04 id1  0.950088  0.62       NaN
               id2 -0.151357  1.89       NaN
               id3 -0.103219  2.50 -0.205158
    

    【讨论】: