【问题标题】:survxai explainer with an mlr3proba model带有 mlr3proba 模型的 survxai 解释器
【发布时间】:2022-06-07 23:57:41
【问题描述】:

我正在尝试从使用 mlr3proba 构建的生存模型构建一个 survxai 解释器。我在创建解释器所需的 predict_function 时遇到问题。有没有人尝试过构建这样的东西?

到目前为止,我的代码如下:

require(survxai)
require(survival)
require(survivalmodels)
require(mlr3proba)
require(mlr3pipelines)

create_pipeops <- function(learner) {
  GraphLearner$new(po(\"encode\") %>>% po(\"scale\") %>>% po(\"learner\", learner))
}

fit<-lrn(\"surv.deepsurv\")
fit<-create_pipeops(fit)

data<-veteran
survival_task<-TaskSurv$new(\"veteran\", veteran, time = \"time\", event = \"status\")
fit$train(survival_task)

predict_function<-function(model, newdata, times=NULL){
  if(!is.data.frame(newdata)){
    newdata <- data.frame(newdata)
  }
  surv_task<-TaskSurv$new(\"task\", newdata, time = \"time\", 
                          event = \"status\")
  pred<-model$predict(surv_task)
  mat<-matrix(pred$data$distr, nrow = nrow(pred$data$distr))
  colnames(mat)<-colnames(pred$data$distr)
  return(mat)
}

explainer<-survxai::explain(model = learner$model, data = veteran[,-c(3,4)],
                            y = Surv(veteran$time, veteran$status),
                            predict_function = predict_function)

pred_breakdown<-prediction_breakdown(explainer, veteran[1,])

它会引发以下错误:[.data.table(r6_private(backend)$.data, , event, with = FALSE) 中的错误: 未找到列:状态,但我怀疑一旦解决了该列,可能还会有更多。我不完全理解函数返回的对象的结构。

在 predict_function 中,我包含了 times 参数,因为根据 R 帮助页面,该函数必须采用三个参数。

  • 你的代码对我来说运行良好。您可以使用 reprex::reprex 提供一个代表吗?
  • @RaphaelS 我编辑了问题,我忘了实际添加创建解释器的代码。我实际上也在尝试使用survivalmodels::deepsurv,因为我更熟悉该语法,但到目前为止该软件包也没有运气。

标签: r survival-analysis survival mlr3 dalex


【解决方案1】:

此处使用 randomForestSRC 的工作示例,您可以将 surv.rfsrc 更改为 surv.deepsurv 作为您的示例。顺便说一句,我们计划很快在 mlr3proba 中实现它,或者我可能只是将它直接添加到生存模型中,仍在决定!

library(mlr3proba)
#> Loading required package: mlr3
#> Warning: package 'mlr3' was built under R version 4.1.3
library(mlr3extralearners)
#> 
#> Attaching package: 'mlr3extralearners'
#> The following objects are masked from 'package:mlr3':
#> 
#>     lrn, lrns
library(survxai)
#> Loading required package: prodlim
#> Welcome to survxai (version: 0.2.1).
#> Information about the package can be found in the GitHub repository: https://github.com/MI2DataLab/survxai
library(survival)
data(pbc, package = "randomForestSRC")
pbc <- pbc[complete.cases(pbc), ]
task <- as_task_surv(pbc, event = "status", time = "days")
split <- partition(task)
predict_times <- function(model, data, times) {
  t(model$predict_newdata(data)$distr$survival(times))
}
model <- lrn("surv.rfsrc")$train(task, row_ids = split$train)
surve_cph <- explain(
  model = model, data = pbc[, -c(1, 2)],
  y = Surv(pbc$days, pbc$status),
  predict_function = predict_times
)
prediction_breakdown(surve_cph, pbc[1, -c(1, 2)])
#>             contribution
#> bili            -35.079%
#> edema           -10.278%
#> ascites          -5.505%
#> copper           -1.084%
#> stage            -0.773%
#> prothrombin      -0.421%
#> albumin          -0.247%
#> sgot             -0.143%
#> hepatom          -0.098%
#> spiders          -0.086%
#> alk              -0.043%
#> trig             -0.041%
#> age              -0.035%

reprex package (v2.0.1) 于 2022-06-07 创建

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2017-08-18
    • 1970-01-01
    • 2011-12-12
    • 1970-01-01
    • 2021-05-03
    • 2016-01-12
    • 2010-10-24
    相关资源
    最近更新 更多