【发布时间】:2021-01-14 16:10:27
【问题描述】:
我正在检查tensorflow.keras 中非常简单的指标对象,例如BinaryAccuracy 或AUC。它们都有 reset_states() 和 update_state() 参数,但我发现它们的文档不足且不清楚。
你能解释一下它们的意思吗?
【问题讨论】:
标签: python tensorflow keras tensorflow2.0 metrics
我正在检查tensorflow.keras 中非常简单的指标对象,例如BinaryAccuracy 或AUC。它们都有 reset_states() 和 update_state() 参数,但我发现它们的文档不足且不清楚。
你能解释一下它们的意思吗?
【问题讨论】:
标签: python tensorflow keras tensorflow2.0 metrics
update_state 测量指标(mean、auc、accuracy),并将它们存储在对象中,以便以后可以使用result 检索:
import tensorflow as tf
mean_object = tf.metrics.Mean()
values = [1, 2, 3, 4, 5]
for ix, val in enumerate(values):
mean_object.update_state(val)
print(mean_object.result().numpy(), 'is the mean of', values[:ix+1])
1.0 is the mean of [1]
1.5 is the mean of [1, 2]
2.0 is the mean of [1, 2, 3]
2.5 is the mean of [1, 2, 3, 4]
3.0 is the mean of [1, 2, 3, 4, 5]
reset_states 将指标重置为零:
mean_object.reset_states()
mean_object.result().numpy()
0.0
我不确定我是否比文档更清楚,我认为它已经很好地解释了。
调用对象,例如mean_object([1, 2, 3, 4]) 将更新指标,并且返回result。
import tensorflow as tf
mean_object = tf.metrics.Mean()
values = [1, 2, 3, 4, 5]
print(mean_object.result())
returned_mean = mean_object(values)
print(mean_object.result())
print(returned_mean)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
【讨论】:
update_state() 而不是简单的accuracy_object(x, y)。这可以在tf.keras 中完成。它是一样的,还是不同的东西?如果是这样,为什么?
mean_object.update_state(a, b) 而不仅仅是调用mean_object(a, b)?见编辑