【问题标题】:Adding scikit-learn (sklearn) prediction to pandas data frame将 scikit-learn (sklearn) 预测添加到 pandas 数据帧
【发布时间】:2016-02-09 06:12:47
【问题描述】:

我正在尝试将 sklearn 预测添加到 pandas 数据帧,以便对预测进行全面评估。相关代码如下:

clf = linear_model.LinearRegression()
clf.fit(Xtrain,ytrain)
ypred = pd.DataFrame({'pred_lin_regr': pd.Series(clf.predict(Xtest))})

数据框如下所示:

Xtest

       axial_MET  cos_theta_r1  deltaE_abs  lep1_eta   lep1_pT  lep2_eta  
8000   1.383026      0.332365    1.061852  0.184027  0.621598 -0.316297   
8001  -1.054412      0.046317    1.461788 -1.141486  0.488133  1.011445   
8002   0.259077      0.429920    0.769219  0.631206  0.353469  1.027781   
8003  -0.096647      0.066200    0.411222 -0.867441  0.856115 -1.357888   
8004   0.145412      0.371409    1.111035  1.374081  0.485231  0.900024   

ytest

8000    1
8001    0
8002    0
8003    0
8004    0

ypred

        pred_lin_regr
0       0.461636
1       0.314448
2       0.363751
3       0.291858
4       0.416056

连接 Xtest 和 ytest 可以正常工作:

df_total = pd.concat([Xtest, ytest], axis=1)

但是事件信息在 ypred 上丢失了。

什么是必须的类似 python/pandas/numpy 的方式来做到这一点?

我正在使用以下版本:

argparse==1.2.1
cycler==0.9.0
decorator==4.0.4
ipython==4.0.0
ipython-genutils==0.1.0
matplotlib==1.5.0
nose==1.3.7
numpy==1.10.1
pandas==0.17.0
path.py==8.1.2
pexpect==4.0.1
pickleshare==0.5
ptyprocess==0.5
py==1.4.30
pyparsing==2.0.5
pytest==2.8.2
python-dateutil==2.4.2
pytz==2015.7
scikit-learn==0.16.1
scipy==0.16.1
simplegeneric==0.8.1
six==1.10.0
sklearn==0.0
traitlets==4.0.0
wsgiref==0.1.2

我尝试了以下方法:

df_total["pred_lin_regr"] = clf.predict(Xtest) 

似乎可以完成这项工作,但我认为我无法确定事件是否正确匹配

【问题讨论】:

    标签: python numpy pandas scikit-learn


    【解决方案1】:

    你的第二行是正确的,df_total["pred_lin_regr"] = clf.predict(Xtest),它更有效。

    在其中,您将获取clf.predict() 的输出,恰好是array,并将其添加到数据帧中。您从数组本身接收到的输出是为了匹配Xtest,因为是这种情况,将它添加到numpy数组将不会改变或改变它顺序。

    这是来自 example 的一个小证明:

    服用以下部分:

    import numpy as np
    
    import pandas as pd
    from sklearn import datasets, linear_model
    
    # Load the diabetes dataset
    diabetes = datasets.load_diabetes()
    
    # Use only one feature
    diabetes_X = diabetes.data[:, np.newaxis, 2]
    
    # Split the data into training/testing sets
    diabetes_X_train = diabetes_X[:-20]
    diabetes_X_test = diabetes_X[-20:]
    
    # Split the targets into training/testing sets
    diabetes_y_train = diabetes.target[:-20]
    diabetes_y_test = diabetes.target[-20:]
    
    # Create linear regression object
    regr = linear_model.LinearRegression()
    
    # Train the model using the training sets
    regr.fit(diabetes_X_train, diabetes_y_train)
    
    print(regr.predict(diabetes_X_test))
    
    df = pd.DataFrame(regr.predict(diabetes_X_test))
    
    print(df)
    

    第一个 print() 函数将按预期为我们提供一个 numpy 数组:

    [ 225.9732401   115.74763374  163.27610621  114.73638965  120.80385422
      158.21988574  236.08568105  121.81509832   99.56772822  123.83758651
      204.73711411   96.53399594  154.17490936  130.91629517   83.3878227
      171.36605897  137.99500384  137.99500384  189.56845268   84.3990668 ]
    

    该顺序与第二个 print() 函数相同,我们在其中将结果添加到数据框:

                 0
    0   225.973240
    1   115.747634
    2   163.276106
    3   114.736390
    4   120.803854
    5   158.219886
    6   236.085681
    7   121.815098
    8    99.567728
    9   123.837587
    10  204.737114
    11   96.533996
    12  154.174909
    13  130.916295
    14   83.387823
    15  171.366059
    16  137.995004
    17  137.995004
    18  189.568453
    19   84.399067
    

    重新运行部分测试的代码,将给我们同样的有序结果:

    print(regr.predict(diabetes_X_test[0:5]))
    
    df = pd.DataFrame(regr.predict(diabetes_X_test[0:5]))
    
    print(df)
    
    [ 225.9732401   115.74763374  163.27610621  114.73638965  120.80385422]
                0
    0  225.973240
    1  115.747634
    2  163.276106
    3  114.736390
    4  120.803854
    

    【讨论】:

    • 如果 Xtest 是随机选择并且索引是随机的,如何解决相同的问题,在这种情况下,我们无法匹配两个数据帧中的每条记录
    猜你喜欢
    • 2016-04-25
    • 2016-04-01
    • 2017-05-07
    • 1970-01-01
    • 2023-03-09
    • 2018-07-29
    • 2015-11-13
    • 1970-01-01
    • 2016-11-26
    相关资源
    最近更新 更多