【问题标题】:How to apply OLS from statsmodels to groupby如何将 OLS 从 statsmodels 应用到 groupby
【发布时间】:2015-12-07 23:52:54
【问题描述】:

我按月对产品运行 OLS。虽然这适用于单个产品,但我的数据框包含许多产品。如果我创建一个 groupby 对象,OLS 会报错。

linear_regression_df:
  product_desc  period_num    TOTALS  
0    product_a     1          53  
3    product_a     2          52 
6    product_a     3          50 
1    product_b     1          44 
4    product_b     2          43 
7    product_b     3          41 
2    product_c     1          36   
5    product_c     2          35 
8    product_c     3          34 


from pandas import DataFrame, Series
import statsmodels.api as sm    

linear_regression_grouped = linear_regression_df.groupby(['product_desc'])
X = linear_regression_grouped['period_num'] 
y = linear_regression_grouped['TOTALS']

model = sm.OLS(y, X)
results = model.fit()

我在 sm.OLS() 行收到此错误:

ValueError: unrecognized data structures: <class 'pandas.core.groupby.SeriesGroupBy'>

那么我怎样才能通过我的数据框并为每个 product_desc 应用 sm.OLS()?

【问题讨论】:

    标签: python pandas statsmodels


    【解决方案1】:

    你可以做这样的事情......

    import pandas as pd
    import statsmodels.api as sm
    
    for products in linear_regression_df.product_desc.unique():
        tempdf = linear_regression_df[linear_regression_df.product_desc == products]
        X = tempdf['period_num']
        y = tempdf['TOTALS']
    
        model = sm.OLS(y, X)
        results = model.fit()
    
        print results.params #  Or whatever summary info you want
    

    【讨论】:

      【解决方案2】:

      使用get_group 获取每个单独的组并对每个组执行 OLS 模型:

      for group in linear_regression_grouped.groups.keys():
          df= linear_regression_grouped.get_group(group)
          X = df['period_num'] 
          y = df['TOTALS']
          model = sm.OLS(y, X)
          results = model.fit()
          print results.summary()
      

      但在实际情况下,您还希望有截距项,因此模型的定义应略有不同:

      for group in linear_regression_grouped.groups.keys():
          df= linear_regression_grouped.get_group(group)
          df['constant']=1
          X = df[['period_num','constant']]
          y = df['TOTALS']
          model = sm.OLS(y,X)
          results = model.fit()
          print results.summary()
      

      结果(有拦截和没有拦截)当然非常不同。

      【讨论】:

      • 我发现这个答案非常有用。我试图实现它并得到一个AttributeError: 'DataFrame' object has no attribute 'get_group' 错误虽然..
      猜你喜欢
      • 2021-10-05
      • 2020-11-23
      • 2023-02-15
      • 2014-07-28
      • 2013-05-01
      • 2022-06-08
      • 2018-10-17
      • 2017-11-02
      • 2021-09-20
      相关资源
      最近更新 更多