【问题标题】:R linear model (lm) predict function with one single arrayR 线性模型 (lm) 用一个数组预测函数
【发布时间】:2016-09-16 04:10:29
【问题描述】:

我在 R 中有一个 lm 模型,我已经对其进行了训练和序列化。在函数内部,我将模型和特征向量(一个数组)作为输入传递,我有:

CREATE OR REPLACE FUNCTION lm_predict(
    feat_vec float[],
    model bytea
)
RETURNS float
AS
$$
    #R-code goes here.
    mdl <- unserialize(model)
    # class(feat_vec) outputs "array"
    y_hat <- predict.lm(mdl, newdata = as.data.frame.list(feat_vec))
    return (y_hat)
$$ LANGUAGE 'plr';

这会返回错误的y_hat!!我知道这一点是因为这个其他解决方案有效(这个函数的输入仍然是模型(在字节数组中)和一个 feat_vec(数组)):

CREATE OR REPLACE FUNCTION lm_predict(
    feat_vec float[],
    model bytea
)
RETURNS float
AS
$$
    #R-code goes here.
    mdl <- unserialize(model)
    coef = mdl$coefficients
    y_hat = coef[1] + as.numeric(coef[-1]%*%feat_vec)
    return (y_hat)
$$ LANGUAGE 'plr';

我做错了什么??它是相同的未序列化模型,第一个选项也应该给我正确的答案...

【问题讨论】:

  • 这是R代码吗?它看起来像半条蟒蛇;冒号在 R 中不起作用,return+ 也不起作用。
  • 是的,它是 R + 伪代码 - 你可以忽略函数声明 实际上 - 这是 Postgres 的 PL/R 函数内部,但我不想把重点放在 Postgres 上
  • ...那么伪代码如何返回结果,正确与否?
  • 我对我的问题进行了一些修改,希望现在很清楚。第一个选项返回错误的数字,而第二个选项返回正确的预测!但是我没有错误
  • 更好,但如果没有a reproducible example,仍然无法回答。

标签: r postgresql lm predict plr


【解决方案1】:

问题似乎是newdata = as.data.frame.list(feat_vec) 的使用。正如您在previous question 中所讨论的,这会返回丑陋的列名。而当您调用 predict 时,newdata 的列名称必须与模型公式中的协变量名称一致。当您致电predict 时,您应该会收到一些警告消息。

## example data
set.seed(0)
x1 <- runif(20)
x2 <- rnorm(20)
y <- 0.3 * x1 + 0.7 * x2 + rnorm(20, sd = 0.1)

## linear model
model <- lm(y ~ x1 + x2)

## new data
feat_vec <- c(0.4, 0.6)
newdat <- as.data.frame.list(feat_vec)
#  X0.4 X0.6
#1  0.4  0.6

## prediction
y_hat <- predict.lm(model, newdata = newdat)
#Warning message:
#'newdata' had 1 row but variables found have 20 rows 

你需要的是

newdat <- as.data.frame.list(feat_vec,
                             col.names = attr(model$terms, "term.labels"))
#   x1  x2
#1 0.4 0.6

y_hat <- predict.lm(model, newdata = newdat)
#        1 
#0.5192413 

这和你可以手动计算的一样:

coef = model$coefficients
unname(coef[1] + sum(coef[-1] * feat_vec))
#[1] 0.5192413 

【讨论】:

  • 从 Postgres 调用 R 时,我看不到警告消息...但肯定有问题
  • 感谢您的回答。对此,我真的非常感激。尽管如此,它仍然对我不起作用, y_hat 返回始终相同的结果,而“手动”计算返回正确的预测。我不明白为什么:/ 为什么我需要包含 col.names?这真的很重要吗?
  • 这解决了我在使用 randomForest 时遇到的问题...谢谢!我仍然得到 lm 的奇怪行为,但很高兴我让它与另一个回归模型和完全相同的代码一起工作!
猜你喜欢
  • 2015-06-21
  • 1970-01-01
  • 1970-01-01
  • 2012-08-18
  • 1970-01-01
  • 2017-02-19
  • 1970-01-01
  • 2018-10-03
  • 1970-01-01
相关资源
最近更新 更多