【发布时间】:2022-06-07 23:57:41
【问题描述】:
我正在尝试从使用 mlr3proba 构建的生存模型构建一个 survxai 解释器。我在创建解释器所需的 predict_function 时遇到问题。有没有人尝试过构建这样的东西?
到目前为止,我的代码如下:
require(survxai)
require(survival)
require(survivalmodels)
require(mlr3proba)
require(mlr3pipelines)
create_pipeops <- function(learner) {
GraphLearner$new(po(\"encode\") %>>% po(\"scale\") %>>% po(\"learner\", learner))
}
fit<-lrn(\"surv.deepsurv\")
fit<-create_pipeops(fit)
data<-veteran
survival_task<-TaskSurv$new(\"veteran\", veteran, time = \"time\", event = \"status\")
fit$train(survival_task)
predict_function<-function(model, newdata, times=NULL){
if(!is.data.frame(newdata)){
newdata <- data.frame(newdata)
}
surv_task<-TaskSurv$new(\"task\", newdata, time = \"time\",
event = \"status\")
pred<-model$predict(surv_task)
mat<-matrix(pred$data$distr, nrow = nrow(pred$data$distr))
colnames(mat)<-colnames(pred$data$distr)
return(mat)
}
explainer<-survxai::explain(model = learner$model, data = veteran[,-c(3,4)],
y = Surv(veteran$time, veteran$status),
predict_function = predict_function)
pred_breakdown<-prediction_breakdown(explainer, veteran[1,])
它会引发以下错误:[.data.table(r6_private(backend)$.data, , event, with = FALSE) 中的错误:
未找到列:状态,但我怀疑一旦解决了该列,可能还会有更多。我不完全理解函数返回的对象的结构。
在 predict_function 中,我包含了 times 参数,因为根据 R 帮助页面,该函数必须采用三个参数。
-
你的代码对我来说运行良好。您可以使用
reprex::reprex提供一个代表吗? -
@RaphaelS 我编辑了问题,我忘了实际添加创建解释器的代码。我实际上也在尝试使用
survivalmodels::deepsurv,因为我更熟悉该语法,但到目前为止该软件包也没有运气。
标签: r survival-analysis survival mlr3 dalex