【Question title】: How to compute FIRM importance measure using VIP package and tidymodels (including recipe)
【Posted】: 2021-03-18 16:04:13
【Question】:

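I would like to compute FIRM importance scores for a model built with a tidymodels workflow. For a reprex, I'll use the iris dataset and try to predict whether an observation is setosa.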

library(tidymodels)
library(readr)
library(vip)

#clean data
iris <- iris %>%
  mutate(class  = case_when(Species == 'setosa' ~ 'setosa',
                            TRUE ~ 'other'))
iris$class = as.factor(iris$class)
iris <- subset(iris, select = -c(Species))

#split data into training and testing
iris_split = initial_split(iris, prop = 0.8)
cv_splits = vfold_cv(training(iris_split), v = 5)

#preprocessing
iris_recipe = recipe(class ~., data = iris) %>%
  step_center(Sepal.Length) %>%
  prep()

#specify random forest model (ranger engine)
model = rand_forest(
  mode = "classification",
  mtry = tune(),
  trees = 50
) %>% 
  set_engine("ranger", importance = "impurity")

#tuning parameters
tuning_grid = grid_regular(mtry(range=c(1,4)), levels = 4)

iris_wkfl = workflow() %>%
  add_recipe(iris_recipe) %>%
  add_model(model) 
  
iris_tune = tune_grid(iris_wkfl,
            resamples = cv_splits,
            grid = tuning_grid,
            metrics = metric_set(accuracy))

best_params = iris_tune %>%
  select_best(metric = "accuracy")

best_model = finalize_workflow(iris_wkfl, best_params) %>%
  parsnip::fit(data = training(iris_split)) %>%
  pull_workflow_fit()

vip(best_model, method = "firm")

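The last line produces an error from the pdp package.

Error in get_training_data.default(object) : The training data could not be extracted from object. Please supply the raw training data using the `train` argument in the call to `partial`.

Is the line below correct? Or do I first need to supply the training data as transformed by my recipe? I want to make sure vip applies my recipe when computing the importance scores. I know the error says "raw training data", but I'm not sure pdp knows about my workflow.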

vip(best_model, method = "firm", train = training(iris_split))

【Comments】:

    Tags: r machine-learning tidymodels vip


    【Solution 1】:

    You will want to take the same approach that I outlined in this answer.

    First tune the model, then train it on the training data:

    library(tidymodels)
    
    #clean data
    iris = iris %>%
      mutate(class  = case_when(Species == 'setosa' ~ 'setosa',
                                TRUE ~ 'other'),
             class = factor(class)) %>%
      select(-Species)
    
    #split data into training and testing
    iris_split = initial_split(iris, prop = 0.8)
    iris_train = training(iris_split)
    iris_test = testing(iris_split)
    cv_splits = vfold_cv(iris_train, v = 5)
    
    #preprocessing
    iris_recipe = recipe(class ~., data = iris_train) %>%
      step_center(Sepal.Length)
    
    #specify ranger model
    rf_spec = rand_forest(
      mode = "classification",
      mtry = tune(),
      trees = 50
    ) %>% 
      set_engine("ranger", importance = "impurity") 
    ## don't need any importance here if you will do it another way; probably remove
    
    #tuning parameters
    tuning_grid = grid_regular(mtry(range=c(1,4)), levels = 4)
    
    iris_wkfl = workflow() %>%
      add_recipe(iris_recipe) %>%
      add_model(rf_spec) 
    
    iris_tune = tune_grid(iris_wkfl,
                          resamples = cv_splits,
                          grid = tuning_grid,
                          metrics = metric_set(accuracy))
    
    best_params = iris_tune %>%
      select_best(metric = "accuracy")
    
    rf_fit = finalize_workflow(iris_wkfl, best_params) %>%
      fit(data = iris_train)
    

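    Your model is now trained, and you can compute model-agnostic variable importance scores such as FIRM. There are a few steps:

    • You pull() the fitted model out of the workflow.
    • You have to specify the target/outcome variable, class.
    • In this case, we need to pass both the original training data (use the preprocessed training data here, which you can get from the recipe) and the right underlying function for predicting from ranger (for most models this is predict(), but unfortunately for ranger it is predictions()).
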
    library(vip)
    rf_fit %>%
      pull_workflow_fit() %>%
      vip(method = "firm", 
          target = "class", metric = "accuracy",
          pred_wrapper = ranger::predictions, 
          train = bake(prep(iris_recipe), new_data = NULL))
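
To see what the `train` argument receives, here is a minimal sketch that uses only the recipes package. The manual split and the names `train_idx` and `baked_train` are assumptions made for illustration, standing in for `initial_split()` and the objects above. It preps the recipe on the training rows and checks that `Sepal.Length` comes back centered, which is the form of the data the fitted ranger model was trained on.

```r
library(recipes)

# Hypothetical stand-in for the training split used above
set.seed(123)
train_idx <- sample(nrow(iris), size = 0.8 * nrow(iris))
iris_train <- iris[train_idx, ]
iris_train$class <- factor(ifelse(iris_train$Species == "setosa", "setosa", "other"))
iris_train$Species <- NULL

# Unprepped recipe, mirroring the one in the answer
iris_recipe <- recipe(class ~ ., data = iris_train)
iris_recipe <- step_center(iris_recipe, Sepal.Length)

# prep() estimates the centering mean from the training data;
# bake(new_data = NULL) returns the processed training set
baked_train <- bake(prep(iris_recipe), new_data = NULL)

mean(iris_train$Sepal.Length)   # mean before centering
mean(baked_train$Sepal.Length)  # approximately 0 after centering
```

Passing the baked data rather than `training(iris_split)` keeps the features on the same scale the model saw during fitting.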
    

    Created on 2020-12-10 by the reprex package (v0.3.0.9001)
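
For intuition about why FIRM needs training data at all, the sketch below computes a FIRM-style score by hand in base R. It is not vip's implementation: it assumes the usual definition for numeric features (the standard deviation of the partial dependence curve), swaps in a plain linear model for the random forest, and the helpers `partial_dependence()` and `firm_score()` are hypothetical names introduced here.

```r
# Binary outcome as in the question, coded 0/1 for a simple linear stand-in model
dat <- iris
dat$is_setosa <- as.integer(dat$Species == "setosa")
dat$Species <- NULL

fit <- lm(is_setosa ~ ., data = dat)

# Partial dependence of one feature: fix it at each grid value,
# predict for every training row, and average the predictions
partial_dependence <- function(fit, train, feature, grid) {
  vapply(grid, function(value) {
    modified <- train
    modified[[feature]] <- value
    mean(predict(fit, newdata = modified))
  }, numeric(1))
}

# FIRM-style score for a numeric feature: the standard deviation
# of its partial dependence curve over an evenly spaced grid
firm_score <- function(fit, train, feature, n_grid = 20) {
  grid <- seq(min(train[[feature]]), max(train[[feature]]), length.out = n_grid)
  sd(partial_dependence(fit, train, feature, grid))
}

features <- setdiff(names(dat), "is_setosa")
scores <- vapply(features, function(f) firm_score(fit, dat, f), numeric(1))
sort(scores, decreasing = TRUE)
```

Because every point on the curve is an average over the training rows, the data passed in has to be in the same form the model was fitted on, which is why the preprocessed training set goes into `train`.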
