【发布时间】:2025-12-10 09:15:01
【问题描述】:
我在文档中找不到信息,所以我在这里问。
我有一个具有 3 个不同输出的多输出模型:
model = tf.keras.Model(inputs=[input], outputs=[output1, output2, output3])
用于验证的预测标签由这 3 个输出构成,仅形成一个,这是一个后处理步骤。用于训练的数据集是这 3 种中间输出的数据集,为了验证,我在标签数据集而不是 3 种中间数据上进行评估。
我想使用一个自定义指标来评估我的模型,该指标处理后处理并与基本事实进行比较。
我的问题是,在自定义指标的代码中,y_pred 会是模型的 3 个输出的列表吗?
class MyCustomMetric(tf.keras.metrics.Metric):
def __init__(self, name='my_custom_metric', **kwargs):
super(MyCustomMetric, self).__init__(name=name, **kwargs)
def update_state(self, y_true, y_pred, sample_weight=None):
# ? is y_pred a list [batch_output_1, batch_output_2, batch_output_3] ?
def result(self):
pass
# one single metric handling the 3 outputs?
model.compile(optimizer=tf.compat.v1.train.RMSPropOptimizer(0.01),
loss=tf.keras.losses.categorical_crossentropy,
metrics=[MyCustomMetric()])
【问题讨论】:
-
你运行了这个,看看它是否有效?如果是这样,您遇到了什么错误?
-
不,我没有,这只是一段看起来像我想要的地方但缺少部分的代码。我正在寻找一个可以帮助我填补缺失部分的答案。
标签: python tensorflow keras tensorflow2.0