[Posted]: 2018-11-21 01:58:34
[Question]:
Problem
I need to compute the Pearson and Spearman correlations and use them as metrics in TensorFlow.
For Pearson, it's simple:
tf.contrib.metrics.streaming_pearson_correlation(y_pred, y_true)
But for Spearman, I have no idea!
What I tried:
From this answer:
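For reference, Spearman's rho is Pearson's correlation computed on the ranks of the values instead of the values themselves. A minimal pure-Python sketch of that definition, assuming there are no tied values (ties would need averaged ranks). The helper names `ranks`, `pearson` and `spearman` are hypothetical, not TensorFlow APIs:

```python
def ranks(values):
    """Return the 0-based ascending rank of each element (assumes no ties)."""
    order = sorted(range(len(values)), key=lambda i: values[i])  # argsort
    result = [0] * len(values)
    for rank, idx in enumerate(order):
        result[idx] = rank  # invert the argsort permutation
    return result


def pearson(x, y):
    """Plain Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    return cov / (vx * vy) ** 0.5


def spearman(x, y):
    """Spearman's rho: the Pearson correlation of the ranks."""
    return pearson(ranks(x), ranks(y))


print(spearman([0.1, 0.4, 0.35, 0.8], [1.0, 3.0, 2.0, 10.0]))  # 1.0
```

So one route in TensorFlow is to turn `y_pred` and `y_true` into rank tensors and feed those to the existing Pearson metric, which is what the second attempt below does.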
samples = 1
predictions_rank = tf.nn.top_k(y_pred, k=samples, sorted=True, name='prediction_rank').indices
real_rank = tf.nn.top_k(y_true, k=samples, sorted=True, name='real_rank').indices
rank_diffs = predictions_rank - real_rank
rank_diffs_squared_sum = tf.reduce_sum(rank_diffs * rank_diffs)
six = tf.constant(6)
one = tf.constant(1.0)
numerator = tf.cast(six * rank_diffs_squared_sum, dtype=tf.float32)
divider = tf.cast(samples * samples * samples - samples, dtype=tf.float32)
spearman_batch = one - numerator / divider
But this returns NaN...
I also tried:
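A likely cause of the NaN: with `samples = 1`, `divider` is `1 - 1 = 0`. In addition, `tf.nn.top_k(..., k=1)` returns only the index of the largest entry along the last axis, not a full ranking. If `y_pred` has a single column (as the error from the second attempt suggests), both index tensors are all zeros, `rank_diffs` is 0, and the metric evaluates `1 - 0/0`, which is NaN in floating point. The textbook formula needs `n` to be the number of ranked samples (e.g. the batch size) and complete ranks. A small sketch of that formula on precomputed rank lists; the function name is hypothetical and the formula is only valid without ties:

```python
def spearman_from_ranks(rank_x, rank_y):
    """rho = 1 - 6 * sum(d^2) / (n^3 - n), for tie-free rank lists."""
    n = len(rank_x)
    denominator = n ** 3 - n  # n * (n^2 - 1), which is 0 when n == 1
    if denominator == 0:
        # The samples = 1 case: 0 / 0, which a float tensor evaluates to NaN.
        raise ValueError("Spearman's rho needs at least 2 ranked samples")
    d_squared = sum((a - b) ** 2 for a, b in zip(rank_x, rank_y))
    return 1 - 6 * d_squared / denominator


print(spearman_from_ranks([0, 2, 1, 3], [0, 1, 2, 3]))  # 0.8
```

Here `d = [0, 1, -1, 0]`, so the result is `1 - 6 * 2 / 60 = 0.8`, the same value Pearson's correlation gives on these two rank lists.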
size = tf.size(y_pred)
indice_of_ranks_pred = tf.nn.top_k(y_pred, k=size)[1]
indice_of_ranks_label = tf.nn.top_k(y_true, k=size)[1]
rank_pred = tf.nn.top_k(-indice_of_ranks_pred, k=size)[1]
rank_label = tf.nn.top_k(-indice_of_ranks_label, k=size)[1]
rank_pred = tf.to_float(rank_pred)
rank_label = tf.to_float(rank_label)
spearman = tf.contrib.metrics.streaming_pearson_correlation(rank_pred, rank_label)
But running this, I get the following error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: input must have at least k columns. Had 1, needed 32
[[{{node metrics/spearman/TopKV2}} = TopKV2[T=DT_FLOAT, sorted=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](lambda_1/add, metrics/pearson/pearson_r/variance_predictions/Size)]]
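The error is consistent with `y_pred` having shape `(32, 1)`: `tf.nn.top_k` sorts along the last dimension, which has only 1 column, while `k = tf.size(y_pred)` is 32. Flattening both tensors first, e.g. with `tf.reshape(y_pred, [-1])`, should let `top_k` rank across the batch. The double `top_k` trick itself looks sound: the first call is a descending argsort, and the second (on the negated indices) returns the inverse permutation, i.e. each element's rank. The pure-Python sketch below mirrors those two calls on a flattened batch; `flatten` and `descending_ranks` are hypothetical names, and ties are not averaged:

```python
def flatten(batch):
    """Turn a (batch, 1) nested list into a flat list, like tf.reshape(x, [-1])."""
    return [row[0] for row in batch]


def descending_ranks(values):
    """Mirror of the two tf.nn.top_k calls in the question, on a flat list."""
    n = len(values)
    # First top_k(values, k=n).indices: positions sorted by value, largest first.
    order = sorted(range(n), key=lambda i: values[i], reverse=True)
    # Second top_k(-order, k=n).indices: sorts `order` ascending and returns
    # the position of each original index, i.e. the inverse permutation.
    return sorted(range(n), key=lambda j: order[j])


print(descending_ranks(flatten([[0.3], [0.9], [0.1]])))  # [1, 0, 2]
```

One caveat: the ranks are computed within each batch, so the streaming Pearson metric would likely aggregate correlations of per-batch ranks rather than give the exact Spearman correlation over the whole dataset.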
[Discussion]:
Tags: python python-3.x tensorflow metrics