【Title】: Stacking models from different packages
【Posted】: 2018-04-14 02:06:36
【Question】:

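I am using an SVM from the e1071 package for bankruptcy prediction (classification). To improve my results, I would like to combine it with a random forest from the caret package. First I will show my RF model, then the SVM model. After that I will show my attempt at combining (stacking) them.

Sorry for the messy code; I am new to all of this.

RF model (caret package)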

library(caret)

set.seed(123)
model.rf <- train(as.factor(year.of.bankruptcy) ~ ., method = "rf", data = training.set)
predict.rf <- predict(model.rf, testing.set[,-1])

RF model accuracy

confusionMatrix(predict.rf, as.factor(testing.set$year.of.bankruptcy), mode = "everything")$overall[1]

-> This gives me the model's accuracy: Accuracy 0.7166667

SVM (e1071 package)

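library(e1071)

set.seed(123)
model1 <- function(k, d, c, g) {
  # fit an SVM classifier with the given kernel, degree, cost and gamma
  model <- svm(year.of.bankruptcy ~ ., data = training.set, type = "C-classification",
               kernel = k, degree = d, cost = c, gamma = g)
  test_x <- testing.set[, -1]               # test-set predictors
  test_y <- testing.set$year.of.bankruptcy  # true test-set classes
  model_prediction <- predict(model, test_x)
  result <- table(model_prediction, test_y) # confusion table
  return(result)
}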

result<-model1(k="radial", d=2, c=2,g=0.1)
result
classAgreement(tab=result, match.names = FALSE)
classAgreement(tab=result, match.names = FALSE)$diag

-> This gives me the model's accuracy: [1] 0.7466667
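For reference, the diag component returned by e1071's classAgreement() is the proportion of observations on the main diagonal of the confusion table. In other words, it is the overall accuracy. The same number can be computed directly in base R. The table below is a made-up example, and e1071::classAgreement(result, match.names = FALSE)$diag should return the same value for it:

```r
# Hypothetical confusion table (rows = predicted class, columns = true class);
# the counts are made up for illustration, not the question's results
result <- as.table(matrix(c(40, 3, 5, 2), nrow = 2,
                          dimnames = list(model_prediction = c("-1", "1"),
                                          test_y = c("-1", "1"))))

# Overall accuracy: correctly classified cases (the diagonal) over all cases
accuracy <- sum(diag(result)) / sum(result)
accuracy
# [1] 0.84
```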

Stacking the models together

library(ROSE)

predictDF <- data.frame(predict.rf, classAgreement(tab=result, match.names = FALSE)$diag, class = testing.set$year.of.bankruptcy)
predictDF_bc <- ROSE(class ~.,predictDF, N=300, p=0.5, seed=12)$data

set.seed(123)

combined.model.gbm <- train(as.factor(class) ~ ., method = "gbm", data = predictDF_bc, distribution = "bernoulli")
combined.prediction.gbm <- predict(combined.model.gbm, predictDF)

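Evaluating the model

confusionMatrix(combined.prediction.gbm, as.factor(testing.set$year.of.bankruptcy))$overall[1]

-> This gives me the accuracy of the stacked model: Accuracy 0.7166667

As you can see, the combined model does not seem to take the SVM into account. Its accuracy is exactly the RF's accuracy and lower than the SVM's. Any suggestions on what I can do?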
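One likely reason the SVM is ignored is that classAgreement(tab = result, match.names = FALSE)$diag is the SVM's overall accuracy. That is a single number, so data.frame() recycles it into a constant column, and a constant column tells the meta-model nothing about individual rows. A stacking meta-model needs the SVM's per-row predictions instead. The base-R sketch below contrasts the two layouts on hypothetical toy labels. The vector names are illustrative and do not come from the question.

```r
# Hypothetical toy labels standing in for testing.set$year.of.bankruptcy and for
# the per-row class predictions of the RF and SVM (not the asker's data)
truth    <- factor(c(-1, -1,  1, -1,  1, -1,  1, -1))
rf_pred  <- factor(c(-1, -1, -1, -1,  1, -1,  1,  1), levels = levels(truth))
svm_pred <- factor(c(-1, -1,  1, -1, -1, -1,  1, -1), levels = levels(truth))

# What the question builds: the SVM enters as its overall accuracy, one number
# that data.frame() recycles into a constant column
svm_accuracy <- mean(svm_pred == truth)
scalar_df <- data.frame(rf = rf_pred, svm = svm_accuracy, class = truth)

# What a stacking meta-model needs: one prediction per model per row
per_row_df <- data.frame(rf = rf_pred, svm = svm_pred, class = truth)

length(unique(scalar_df$svm))   # 1: no information about individual rows
length(unique(per_row_df$svm))  # 2: varies with each observation
```

In the question's code, this would mean returning model_prediction from model1() (or calling predict() on the fitted SVM) and putting that vector into predictDF in place of the classAgreement(...)$diag term. Also note that fitting the meta-model on test-set predictions and then scoring it on the same test set gives an optimistic estimate. The caretEnsemble approach below relies on out-of-fold predictions instead.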

  > dput(training.set[sample(1:nrow(training.set), 50),])

structure(list(year.of.bankruptcy = c(-1, -1, -1, -1, -1, -1, -1, 1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, 1, -1, -1, -1, -1, -1), liquidity_1 = c(90.0695931477516, 
85.4305617311398, 76.2455934195065, 4.34688111280157, 159.020111900801, 
104.569486404834, 58.3391003460208, 42.0907973873116, 101.121495327103, 
94.3786295005807, 47.7552816901408, 125.702184574231, 125.763725699637, 
106.584557081952, 0, 143.6203466894, 82.5245328673209, 35.296442687747, 
8.85744561490993, 12.4657534246575, 128.164489183979, 133.131146034372, 
92.0528568769775, 22.8177150192555, 100.237812128419, 40.0340715502555, 
91.360486091332, 129.123757904246, 92.9165443694355, 130.999694283094, 
22.2526106414719, 101.714770797963, 93.1704260651629, 46.6268560361524, 
125.838858750251, 106.076759061834, 86.787017476474, 84.7495991700462, 
42.1171171171171, 68.806311160926, 93.1549687282835, 104.196667352397, 
47.0834921845215, 77.8816199376947, 76.9065981148243, 90.988709507228, 
98.9704873026767, 163.446031970576, 113.768115942029, 92.9742188833874
), profmarg_1 = c(241.916488222698, 215.221579961464, 633.490011750881, 
0, 173.627703009224, 193.164652567976, 3.32179930795848, 82.390221819828, 
131.842456608812, 102.044134727062, 0, 7.2447614801605, 113.608203375347, 
169.208905731881, 0, 179.866439329355, 250.396558677242, 48.0632411067194, 
0, 12.8082191780822, 0.963803812379525, 0, 452.279918109064, 
0, 16.4090368608799, 11.4449434722007, 173.331434539068, 240.216802168022, 
307.709617454261, 179.883827575665, 281.476877175535, 539.609507640068, 
183.12447786132, 31.8431245965139, 151.215591721921, 95.3980099502487, 
259.97695410025, 174.073375459776, 11.986986986987, 160.94322541708, 
119.110493398193, 428.03949804567, 194.624475791079, 325.877466251298, 
37.2322193658955, 245.71066793289, 207.343857240906, 22.49257320696, 
43.6487638533674, 97.4987194809629), drmarg = c(1.46603230803275, 
12.6575304731079, -0.798553144129104, 53.3333333333333, 11.8097892353249, 
29.1893259137473, 60.4166666666667, -23.041601255887, 1.21518987341772, 
6.1535019019915, 82.4626865671642, -4, 4.47536667920271, -3.69540873460246, 
65.3543307086614, 6.46738701790362, -3.63987759703656, 0.575657894736842, 
70.2460850111857, 45.4545454545455, -724.444444444444, 18.809947734191, 
3.22818215293973, 92.9292929292929, 6.52173913043478, 50.8680555555556, 
4.88031987730733, 19.9684115523466, 1.1446376903755, 13.3729821580289, 
1.22027317479027, 4.0986955838441, -3.29607664233577, 73.4414597060314, 
3.95960669678448, 28.6645874681032, 17.2991867598802, 10.8455534851063, 
55.741127348643, 8.98526582981339, 7.36196319018405, 4.85894170231172, 
10.4852855193919, -1.6774275224712, 16.3210702341137, 2.47726693294808, 
5.64784053156146, 59.622641509434, 11.0029211295034, 50.5987773218323
), ROA = c(3.546573875803, 27.2417370683267, -5.05875440658049, 
6.52032166920235, 20.5050657795252, 87.1601208459215, 2.00692041522491, 
-18.9840263855655, 1.60213618157543, 6.38792102206736, 9.72711267605634, 
-0.356665180561748, 5.08438367870113, -6.25296068214116, 3.53041259038707, 
11.6510372264848, -9.11412824304342, 0.276679841897233, 5.87171975316337, 
5.82191780821918, -6.98222317412722, 30.0983365499495, 14.6845337800112, 
11.8100128369705, 1.07015457788347, 6.05028134840741, 8.45912845343207, 
47.9674796747967, 3.52216025829175, 24.0599205136044, 4.37593237195425, 
22.1392190152801, -6.0359231411863, 23.3860555196901, 5.98754269640346, 
35.9275053304904, 46.5719224121375, 18.9380364047911, 6.68168168168168, 
19.5326981937319, 9.17303683113273, 20.7981896729068, 20.5108654212734, 
-5.50363447559709, 10.4541559554413, 6.15173578136541, 12.4456646076413, 
13.4106662894327, 4.81670929241262, 51.5793068123613), debt_ratio_1 = c(75.6423982869379, 
157.077219504965, 180.975323149236, 88.958921973484, 96.869801905338, 
93.0513595166163, 78.6159169550173, 131.707948004915, 132.096128170895, 
100.789779326365, 28.080985915493, 48.1497993758359, 85.6868190557573, 
85.5518711511132, 75.4714305969091, 92.0431940892299, 123.551552628041, 
43.8735177865613, 89.2601134451162, 69.0547945205479, 29.727993146284, 
110.265600588181, 154.662199888331, 54.2362002567394, 20.9274673008323, 
79.0666460172423, 150.536409380044, 101.355013550135, 145.827218471774, 
45.2155304188322, 123.222277473894, 134.90662139219, 123.141186299081, 
41.7043253712072, 66.2648181635523, 26.5813788201848, 95.1411561359708, 
105.191926813166, 7.60760760760761, 179.997413458637, 92.7032661570535, 
121.49763423164, 96.3400686237133, 129.823468328141, 39.502999143102, 
136.213991769547, 119.01166781057, 84.8210496534163, 8.99403239556692, 
113.957657503842), young = c(1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1), medium_age = c(0, 
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 
0, 1, 0, 1, 0, 0, 0), old = c(0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 
1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 
1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0), agriculture = c(0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0), offshore_shipping = c(0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0), transport = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), manufacturing = c(0, 
0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0), telecom_it_tech = c(0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0), electricity = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), construction = c(0, 
0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 
1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 1), wholesale_retail = c(0, 0, 1, 0, 1, 0, 
0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 
0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 
0, 0), finance = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), change_output = c(0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 0.0495184733103549, 
0.0495184733103549), oil_price_dummy = c(0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0), fish_price_dummy = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0.180737819481274, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.180737819481274, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0)), .Names = c("year.of.bankruptcy", "liquidity_1", 
"profmarg_1", "drmarg", "ROA", "debt_ratio_1", "young", "medium_age", 
"old", "agriculture", "offshore_shipping", "transport", "manufacturing", 
"telecom_it_tech", "electricity", "construction", "wholesale_retail", 
"finance", "change_output", "oil_price_dummy", "fish_price_dummy"
), row.names = c(19L, 49L, 25L, 53L, 56L, 3L, 31L, 50L, 58L, 
62L, 51L, 24L, 35L, 29L, 6L, 44L, 12L, 2L, 15L, 42L, 39L, 30L, 
27L, 40L, 26L, 41L, 21L, 22L, 11L, 63L, 32L, 60L, 36L, 52L, 1L, 
14L, 37L, 34L, 8L, 43L, 4L, 10L, 9L, 54L, 59L, 64L, 23L, 20L, 
17L, 13L), class = "data.frame")

【Discussion】:

    Tags: r merge classification svm r-caret


    【Answer 1】:

    Models can be stacked quite easily with the caretEnsemble package.
    Here is an example:

    library(mlbench) #for the data set
    library(caret)
    library(caretEnsemble)
    
    data(PimaIndiansDiabetes)
    set.seed(123)
    

    List the algorithms to use:

    algorithmList <- c("svmRadial", "rf" ) 
    

    If you want to specify tuning parameters for each model, use the tuneList argument of caretList().

    In trainControl, savePredictions = "final" and classProbs = TRUE are required:

    control <- trainControl(method = "repeatedcv", number = 4, repeats = 3, 
                            savePredictions = "final" , classProbs = TRUE)
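Roughly speaking, savePredictions = "final" keeps each model's held-out (out-of-fold) predictions for the final tuning parameters, and classProbs = TRUE stores class probabilities. This lets caretStack() train the meta-model on predictions that the base models did not fit. To show what "out-of-fold" means, here is a base-R sketch. It swaps in glm() on the built-in mtcars data (am as the binary outcome) for the caret models, so it illustrates the idea and is not caret's internal code.

```r
set.seed(123)
k <- 4
n <- nrow(mtcars)

# Assign every row to exactly one of k folds
folds <- sample(rep(seq_len(k), length.out = n))

# Out-of-fold probability that am == 1 (manual transmission)
oof_prob <- numeric(n)
for (i in seq_len(k)) {
  held_out <- folds == i
  # fit on the other k - 1 folds only
  fit <- glm(am ~ wt, data = mtcars[!held_out, ], family = binomial)
  # predict only the rows this fit has not seen
  oof_prob[held_out] <- predict(fit, newdata = mtcars[held_out, ], type = "response")
}

head(round(oof_prob, 3))
```

Every row ends up with exactly one prediction from a model that was not trained on it. Columns like oof_prob, one per base model, are what a stacking meta-model is fitted on.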
    
    models <- caretList(diabetes ~ ., data = PimaIndiansDiabetes, trControl = control,
                         metric = "Kappa", methodList = algorithmList)
    
    results <- resamples(models)
    
    summary(results)
    #output
    Call:
    summary.resamples(object = results)
    
    Models: svmRadial, rf 
    Number of resamples: 12 
    
    Accuracy 
                   Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
    svmRadial 0.6979167 0.7135417 0.7343750 0.7304688 0.7447917 0.7604167    0
    rf        0.7291667 0.7604167 0.7682292 0.7690972 0.7760417 0.8125000    0
    
    Kappa 
                   Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
    svmRadial 0.2637842 0.3570103 0.4053130 0.3917770 0.4394767 0.4775359    0
    rf        0.3788379 0.4612661 0.4788076 0.4809233 0.5028566 0.5785880    0
    

    Now the stack:

    stack.glm <- caretStack(models, method = "glm", metric = "Kappa", trControl = control)
    print(stack.glm)
    #output
    A glm ensemble of 2 base models: svmRadial, rf
    
    Ensemble results:
    Generalized Linear Model 
    
    2304 samples
       2 predictor
       2 classes: 'neg', 'pos' 
    
    No pre-processing
    Resampling: Cross-Validated (4 fold, repeated 3 times) 
    Summary of sample sizes: 1728, 1728, 1728, 1728, 1728, 1728, ... 
    Resampling results:
    
      Accuracy   Kappa    
      0.7667824  0.4685406
    

    Or a gbm stack:

    stack.gbm <- caretStack(models, method="gbm", metric = "Kappa", trControl = control)
    
    print(stack.gbm)
    #output
    A gbm ensemble of 2 base models: svmRadial, rf
    
    Ensemble results:
    Stochastic Gradient Boosting 
    
    2304 samples
       2 predictor
       2 classes: 'neg', 'pos' 
    
    No pre-processing
    Resampling: Cross-Validated (4 fold, repeated 3 times) 
    Summary of sample sizes: 1728, 1728, 1728, 1728, 1728, 1728, ... 
    Resampling results across tuning parameters:
    
      interaction.depth  n.trees  Accuracy   Kappa    
      1                   50      0.7693866  0.4832061
      1                  100      0.7675058  0.4785977
      1                  150      0.7663484  0.4753614
      2                   50      0.7662037  0.4748160
      2                  100      0.7638889  0.4684015
      2                  150      0.7634549  0.4653090
      3                   50      0.7630208  0.4657834
      3                  100      0.7612847  0.4606506
      3                  150      0.7569444  0.4511977
    
    Tuning parameter 'shrinkage' was held constant at a value of 0.1
    Tuning parameter 'n.minobsinnode' was
     held constant at a value of 10
    Kappa was used to select the optimal model using  the largest value.
    The final values used for the model were n.trees = 50, interaction.depth = 1, shrinkage = 0.1 and n.minobsinnode
     = 10.
    

    So the mean resampled Kappa values are:
    svm: 0.3917770
    rf: 0.4809233
    glm ensemble: 0.4685406
    gbm ensemble: 0.4832061 (probably higher if more base models were used)
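Kappa corrects accuracy for the agreement expected by chance. This matters for imbalanced outcomes such as the bankruptcy data, where most rows in the posted sample are -1. Below is a minimal base-R sketch of Cohen's kappa computed from a confusion table. The function name cohen_kappa and the toy counts are hypothetical.

```r
# Hypothetical helper: Cohen's kappa from a square confusion table
cohen_kappa <- function(tab) {
  n  <- sum(tab)
  po <- sum(diag(tab)) / n                      # observed agreement (accuracy)
  pe <- sum(rowSums(tab) * colSums(tab)) / n^2  # agreement expected by chance
  (po - pe) / (1 - pe)
}

# Toy table (rows = predicted, columns = true): a classifier that always
# predicts the majority class "-1" on 45 negatives and 5 positives
always_neg <- matrix(c(45, 0, 5, 0), nrow = 2,
                     dimnames = list(pred = c("-1", "1"), truth = c("-1", "1")))

sum(diag(always_neg)) / sum(always_neg)  # accuracy: 0.9
cohen_kappa(always_neg)                  # kappa: 0, no better than chance
```

A model can therefore reach high accuracy on this kind of data while having a Kappa of zero. That is why metric = "Kappa" is a reasonable choice here.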

    EDIT: using the data provided in the question:

    First convert year.of.bankruptcy to a factor:

    data$year.of.bankruptcy <- as.factor(data$year.of.bankruptcy)
    

    Set the level names to something that won't throw an error:

    levels(data$year.of.bankruptcy) <- c("minus", "plus")
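With classProbs = TRUE, caret uses the outcome's levels as column names for the predicted class probabilities, so every level must be a valid R name. With the original -1/1 coding the levels are "-1" and "1", and neither is a valid name. A small base-R illustration:

```r
# The outcome as stored in the question: numeric -1 / 1
y <- factor(c(-1, 1, -1, -1, 1))
levels(y)
# [1] "-1" "1"

# "-1" and "1" are not syntactically valid R names; make.names() shows what
# they would be turned into when used as column names
make.names(levels(y))
# [1] "X.1" "X1"

# levels<- renames by position: the first level ("-1") becomes "minus",
# the second ("1") becomes "plus"
levels(y) <- c("minus", "plus")
as.character(y)
```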
    

    Moving on:

    control <- trainControl(method = "repeatedcv", number = 4, repeats = 3, 
                            savePredictions = "final" , classProbs = TRUE)
    
    models <- caretList(year.of.bankruptcy ~ ., data = data, trControl = control,
                        metric = "Kappa", methodList = algorithmList)
    

    I get warnings about zero-variance predictors, but that is probably caused by the small data sample. If you see a message like this:

    In .local(x, ...) : Variable(s) `' constant. Cannot scale data.
    

    on the full data set, then it is worth looking into removing near-zero-variance predictors. There is a good chapter on this here. Good luck!
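caret provides nearZeroVar() for this, and it also flags predictors that are nearly, not just exactly, constant. For the exactly-constant case, a small base-R helper is enough. In the posted 50-row sample, columns such as offshore_shipping, electricity, finance, change_output and oil_price_dummy each take a single value, which is the likely source of the warnings. The helper name drop_constant_cols and the toy data frame below are hypothetical.

```r
# Hypothetical helper: drop columns that take only one distinct value,
# optionally protecting some columns (e.g. the outcome) from removal
drop_constant_cols <- function(df, keep = character(0)) {
  is_constant <- vapply(df, function(col) length(unique(col)) <= 1, logical(1))
  df[, !is_constant | names(df) %in% keep, drop = FALSE]
}

# Toy data frame mimicking a few columns of the posted sample
toy <- data.frame(year.of.bankruptcy = c(-1, 1, -1, -1),
                  liquidity_1 = c(90.1, 42.1, 101.1, 4.3),
                  finance = 0,
                  change_output = 0.0495184733103549)

cleaned <- drop_constant_cols(toy, keep = "year.of.bankruptcy")
names(cleaned)
# [1] "year.of.bankruptcy" "liquidity_1"
```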
