【问题标题】:Update model parameters using parsnip without re-specifying arguments for the fit function — 使用 parsnip 更新模型参数，而无需为 fit() 函数重新指定参数
【发布时间】:2020-05-01 04:08:47
【问题描述】:

我正在使用 Titanic 数据集试用 parsnip 包。

library(titanic)
library(dplyr)
library(tidymodels)
library(rattle)
library(rpart.plot)
library(RColorBrewer)

# Question setup: prepare the Titanic data, fit a classification tree with
# parsnip, plot it, and predict on the test set.

# Training data: convert the outcome (Survived) and the categorical
# predictors (Sex, Embarked) to factors, so Survived can serve as the outcome
# of the classification tree specified below.
train <- titanic_train %>%
  mutate(Survived = factor(Survived),
         Sex = factor(Sex),
         Embarked = factor(Embarked)) 

# Test data: apply the same factor conversion to the predictors.
test <- titanic_test %>%
  mutate(Sex = factor(Sex),
         Embarked = factor(Embarked)) 

# Model specification: a decision tree in classification mode, using the
# rpart engine. No main arguments are set yet.
spec_obj <-
  decision_tree(mode = "classification") %>% 
  set_engine("rpart")
spec_obj

# Fit the specification. Note that the formula and the data set must be
# passed to fit() explicitly.
fit_obj <- 
  spec_obj %>% 
  fit(Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked, data = train)
fit_obj

# fit_obj$fit is the underlying engine (rpart) model, which rattle's
# fancyRpartPlot() draws.
fancyRpartPlot(fit_obj$fit)

# Predict on the test set with the fitted parsnip model.
pred <- 
  fit_obj %>%
  predict(new_data = test)
pred

假设我想为模型规范（model specification）添加一些参数，例如 min_n 和 cost_complexity：

# Change the specification's main arguments. update() returns a modified
# spec, which is assigned back to spec_obj here.
spec_obj <- update(spec_obj, min_n = 50, cost_complexity = 0)
# Refit: the formula and the data set have to be repeated in fit() even
# though only the spec changed. This repetition is what the question asks
# how to avoid.
fit_obj <-
  spec_obj %>% 
  fit(Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked, data = train)
fit_obj
# Plot the underlying rpart tree of the refitted model.
fancyRpartPlot(fit_obj$fit)

有什么方法可以避免在 fit() 函数中再次指定模型公式和数据集？

==============编辑================
我发现可以将公式保存在一个变量中：

# Keep the model formula in a variable so it only has to be written once.
# A formula literal is used instead of as.formula("<string>"). It is checked
# by the R parser like any other code, and nothing has to be parsed from a
# string at run time.
f <- Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked
# Refit the (possibly updated) spec using the stored formula.
fit_obj <-
  spec_obj %>%
  fit(f, data = train)
fit_obj

还有更好的办法吗?

【问题讨论】:

    标签: r tidymodels


    【解决方案1】:

    我认为最好的方法是创建一个小的包装函数，例如命名为 fit_titanic()：

    library(titanic)
    library(dplyr)
    #> 
    #> Attaching package: 'dplyr'
    #> The following objects are masked from 'package:stats':
    #> 
    #>     filter, lag
    #> The following objects are masked from 'package:base':
    #> 
    #>     intersect, setdiff, setequal, union
    library(tidymodels)
    #> ── Attaching packages ────────────────────────────────────────────── tidymodels 0.1.0 ──
    #> ✓ broom     0.5.5      ✓ recipes   0.1.10
    #> ✓ dials     0.0.6      ✓ rsample   0.0.6 
    #> ✓ ggplot2   3.3.0      ✓ tibble    3.0.1 
    #> ✓ infer     0.5.1      ✓ tune      0.1.0 
    #> ✓ parsnip   0.1.0      ✓ workflows 0.1.1 
    #> ✓ purrr     0.3.4      ✓ yardstick 0.0.6
    #> ── Conflicts ───────────────────────────────────────────────── tidymodels_conflicts() ──
    #> x purrr::discard()  masks scales::discard()
    #> x dplyr::filter()   masks stats::filter()
    #> x dplyr::lag()      masks stats::lag()
    #> x ggplot2::margin() masks dials::margin()
    #> x recipes::step()   masks stats::step()
    
    # Prepare the training data: convert the outcome (Survived) and the
    # categorical predictors (Sex, Embarked) to factors.
    train <- titanic_train %>%
      mutate(Survived = factor(Survived),
             Sex = factor(Sex),
             Embarked = factor(Embarked)) 
    
    
    # Base model specification: a classification decision tree using the
    # rpart engine, with no main arguments set yet.
    spec1 <-
      decision_tree(mode = "classification") %>% 
      set_engine("rpart")
    
    spec1
    #> Decision Tree Model Specification (classification)
    #> 
    #> Computational engine: rpart
    
    # Fit a parsnip model specification to the Titanic training data.
    #
    # Only the spec has to be supplied. The formula and the training data
    # default to the ones used throughout this example, so a spec changed
    # with update() can be refit without repeating them. Both can still be
    # overridden when needed.
    #
    # Arguments:
    #   spec    - a parsnip model specification (e.g. spec1 or spec2).
    #   formula - model formula; defaults to the Titanic survival formula.
    #   data    - training data; defaults to the `train` data frame found in
    #             the environment where fit_titanic() is defined. It is
    #             evaluated lazily, i.e. when fit() first uses it.
    #   ...     - further arguments forwarded unchanged to fit().
    #
    # Returns the result of fit(spec, ...), i.e. the parsnip model object.
    fit_titanic <- function(spec,
                            formula = Survived ~ Pclass + Sex + Age + SibSp +
                              Parch + Fare + Embarked,
                            data = train,
                            ...) {
      fit(spec, formula = formula, data = data, ...)
    }
    
    
    # Fit the base spec. spec1 sets no main arguments (its printout above
    # lists none), so only the spec has to be passed.
    fit_titanic(spec1)
    #> parsnip model object
    #> 
    #> Fit time:  17ms 
    #> n= 891 
    #> 
    #> node), split, n, loss, yval, (yprob)
    #>       * denotes terminal node
    #> 
    #>   1) root 891 342 0 (0.61616162 0.38383838)  
    #>     2) Sex=male 577 109 0 (0.81109185 0.18890815)  
    #>       4) Age>=6.5 553  93 0 (0.83182640 0.16817360) *
    #>       5) Age< 6.5 24   8 1 (0.33333333 0.66666667)  
    #>        10) SibSp>=2.5 9   1 0 (0.88888889 0.11111111) *
    #>        11) SibSp< 2.5 15   0 1 (0.00000000 1.00000000) *
    #>     3) Sex=female 314  81 1 (0.25796178 0.74203822)  
    #>       6) Pclass>=2.5 144  72 0 (0.50000000 0.50000000)  
    #>        12) Fare>=23.35 27   3 0 (0.88888889 0.11111111) *
    #>        13) Fare< 23.35 117  48 1 (0.41025641 0.58974359)  
    #>          26) Embarked=S 63  31 0 (0.50793651 0.49206349)  
    #>            52) Fare< 10.825 37  15 0 (0.59459459 0.40540541) *
    #>            53) Fare>=10.825 26  10 1 (0.38461538 0.61538462)  
    #>             106) Fare>=17.6 10   3 0 (0.70000000 0.30000000) *
    #>             107) Fare< 17.6 16   3 1 (0.18750000 0.81250000) *
    #>          27) Embarked=C,Q 54  16 1 (0.29629630 0.70370370) *
    #>       7) Pclass< 2.5 170   9 1 (0.05294118 0.94705882) *
    
    # update() returns a modified copy of the spec; spec1 itself is left
    # unchanged. Only the spec differs between the two fits. The formula and
    # the data are supplied by fit_titanic(), so they are not repeated here.
    spec2 <- update(spec1, min_n = 50, cost_complexity = 0)
    
    # Refit with the new arguments. Compared with the first printout, the
    # resulting tree has fewer terminal nodes (7 instead of 9).
    fit_titanic(spec2)
    #> parsnip model object
    #> 
    #> Fit time:  10ms 
    #> n= 891 
    #> 
    #> node), split, n, loss, yval, (yprob)
    #>       * denotes terminal node
    #> 
    #>  1) root 891 342 0 (0.61616162 0.38383838)  
    #>    2) Sex=male 577 109 0 (0.81109185 0.18890815)  
    #>      4) Age>=6.5 553  93 0 (0.83182640 0.16817360) *
    #>      5) Age< 6.5 24   8 1 (0.33333333 0.66666667) *
    #>    3) Sex=female 314  81 1 (0.25796178 0.74203822)  
    #>      6) Pclass>=2.5 144  72 0 (0.50000000 0.50000000)  
    #>       12) Fare>=23.35 27   3 0 (0.88888889 0.11111111) *
    #>       13) Fare< 23.35 117  48 1 (0.41025641 0.58974359)  
    #>         26) Embarked=S 63  31 0 (0.50793651 0.49206349)  
    #>           52) Fare< 10.825 37  15 0 (0.59459459 0.40540541) *
    #>           53) Fare>=10.825 26  10 1 (0.38461538 0.61538462) *
    #>         27) Embarked=C,Q 54  16 1 (0.29629630 0.70370370) *
    #>      7) Pclass< 2.5 170   9 1 (0.05294118 0.94705882) *
    

    reprex package (v0.3.0) 于 2020 年 4 月 30 日创建

    【讨论】:

    • 非常好!谢谢!
    猜你喜欢
    • 1970-01-01
    • 2020-08-19
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-03-05
    • 2020-08-26
    • 1970-01-01
    相关资源
    最近更新 更多