【问题标题】:Accuracy and error rate of example Siamese network in KerasKeras中示例Siamese网络的准确率和错误率
【发布时间】:2020-04-11 19:44:45
【问题描述】:

我一直在关注这个例子here,我想知道这个准确度函数是如何工作的:

def compute_accuracy(y_true, y_pred):
'''Compute classification accuracy with a fixed threshold on distances.
'''
    pred = y_pred.ravel() < 0.5
    return np.mean(pred == y_true)

据我所知,在这种情况下,网络的输出将是两对之间的距离。那么在这种情况下我们如何计算准确率呢? “0.5”阈值指的是什么?另外,如何计算错误率?

【问题讨论】:

    标签: machine-learning keras deep-learning siamese-network


    【解决方案1】:

    对那个例子的理解似乎有一些空白,需要先填补:

    如果您研究数据准备步骤(即create_pairs 方法),您会意识到正对(即属于同一类的样本对)被分配一个标签 1(即正/真),而负对(即属于不同类别的样本对)的标签为 0(即负/假)。

    此外,示例中的连体网络被设计成给定一对样本作为输入,它将预测它们的距离作为输出。通过使用对比损失作为模型的损失函数,对模型进行训练,使得给定一个正对作为输入,预测一个小的距离值(因为它们属于同一类,因此它们的距离应该很低,即传达相似性)并且给定一个负对作为输入,预测一个大的距离值(因为它们属于差异类,因此它们的距离应该很高,即传达不同)。作为练习,尝试通过在代码中使用对比损失定义以数字方式考虑这些点(即当y_true 为 1 和y_true 为 0 时)来确认这些点。

    因此,示例中的准确度函数是这样实现的,即在预测距离值上应用一个固定的任意阈值,即 0.5,即y_pred(这意味着此示例的作者有决定小于 0.5 的距离值表示正对;您可能决定使用另一个阈值,但它应该是基于实验/经验的合理选择)。然后将结果与真正的标签值进行比较,即y_true:

    • y_pred 小于 0.5(y_pred &lt; 0.5 将等于 True):如果y_true 为 1(即正数),则这意味着网络的预测与真实情况一致标签(即True == 1 等于True),因此该样本的预测计入正确预测(即准确度)。但是,如果 y_true 为 0(即负数),则此样本的预测不正确(即 True == 0 等于 False),因此这不会有助于正确预测。

    • y_pred 等于或大于0.5 时(y_pred &lt; 0.5 将等于False):适用与上述相同的推理(留作练习!)。

    (注意:不要忘记模型是在批量样本上训练的。因此,y_predy_true 不是单个值;相反,它们是值数组,并且上面提到的所有计算/比较都是按元素应用的)。

    让我们看一个包含 5 个样本对的输入批次的(虚构的)数值示例,以及如何计算该批次模型预测的准确度:

    >>> y_pred = np.array([1.5, 0.7, 0.1, 0.3, 3.2])
    >>> y_true = np.array([1, 0, 0, 1, 0])
    
    >>> pred = y_pred < 0.5
    >>> pred
    array([False, False,  True,  True, False])
    
    >>> result = pred == y_true
    >>> result
    array([False,  True, False,  True,  True])
    
    >>> accuracy = np.mean(result)
    >>> accuracy
    0.6
    

    【讨论】:

    • 我正在使用这个函数来计算精度。我的方法正确吗? pred = y_pred.ravel()
    • @AtheerAbdullatif 是的,这也是计算准确性的正确方法。 sklearn.metrics.accuracy_score(y_true, pred)np.mean(pred == y_true) 完全相同。
    • 我真的很感激。我现在正在寻找方法来定义我想使用 ROC 的准确度阈值,但问题是在将预测标签传递给 ROC 之前我仍然需要定义一个阈值:pred = y_pred.ravel()
    • @AtheerAbdullatif 我认为您应该通过实验找到它,即尝试不同的阈值,看看哪一个对验证/测试数据最有效。
    猜你喜欢
    • 2019-06-13
    • 2012-04-15
    • 1970-01-01
    • 1970-01-01
    • 2018-09-18
    • 2020-04-23
    • 2019-07-14
    • 2019-05-26
    • 1970-01-01
    相关资源
    最近更新 更多