【问题标题】:Implement a custom loss function in Tensorflow BoostedTreesEstimator在 TensorFlow BoostedTreesEstimator 中实现自定义损失函数
【发布时间】:2020-07-29 14:06:50
【问题描述】:

我正在尝试使用 Tensorflow“BoostedTreesRegressor”来实现提升模型。

为此,我需要实现一个自定义损失函数,在训练期间,将根据自定义函数中定义的逻辑计算损失,而不是使用通常的 mean_squared_error。

我在文章中读到,这可以通过指定头部来使用接口“BoostedTreesEstimator”来实现。因此,我尝试按如下方式实现我的模型:

#define custom loss function to calculate smape
def custom_loss_fn(labels, logits):
    return (np.abs(logits - labels) / (np.abs(logits) + np.abs(labels))) * 2


#create input functions
def make_input_fn(X, y, n_epochs=None, shuffle=True):
    def input_fn():
        dataset = tf.data.Dataset.from_tensor_slices((dict(X), y))
        if shuffle:
            dataset = dataset.shuffle(NUM_EXAMPLES)
        dataset = dataset.repeat(n_epochs)  
        dataset = dataset.batch(NUM_EXAMPLES)  
        return dataset
    return input_fn


train_input_fn = make_input_fn(dftrain, y_train)
eval_input_fn = make_input_fn(dfeval, y_eval, n_epochs=1, shuffle=False)

my_head = tf.estimator.RegressionHead(loss_fn=custom_loss_fn)

#Training a boosted trees model
est = tf.estimator.BoostedTreesEstimator(feature_columns,
                                         head=my_head,
                                         n_batches_per_layer=1,  
                                         n_trees=90,
                                         max_depth=2)

est.train(train_input_fn, max_steps=100)
predictions = list(est.predict(eval_input_fn))

此代码提供了如下错误: 'Head 的子类必须实现 create_estimator_spec() 或 'NotImplementedError: Head 的子类必须实现 create_estimator_spec() 或 _create_tpu_estimator_spec()。

正如我在文章中读到的,create_estimator_spec() 用于在创建新 Estimator 时定义 model_fn() 时使用。在这里,我不想创建任何新模型或 Estimator,我只想在训练时使用自定义损失函数(而不是默认均方误差),其中训练模型应该等于 BoostedTreesRegressor/BoostingTreesEstimator。

如果有人能给我一些实现这个模型的提示,那将是一个很大的帮助。

【问题讨论】:

    标签: python tensorflow boosting


    【解决方案1】:

    确保您没有在损失函数中使用 numpy 函数——您不能将张量转换为 numpy 数组。尝试用 tf.abs 替换 np.abs。您可能会收到 NotImplementedError,因为您的损失函数正在中断。

    【讨论】:

      猜你喜欢
      • 2019-12-16
      • 2021-06-07
      • 1970-01-01
      • 1970-01-01
      • 2018-10-28
      • 2017-12-29
      • 2019-06-13
      • 1970-01-01
      • 2018-12-28
      相关资源
      最近更新 更多