【问题标题】:TensorFlow: questions regarding tf.argmax() and tf.equal()TensorFlow:关于 tf.argmax() 和 tf.equal() 的问题
【发布时间】:2017-06-02 04:35:25
【问题描述】:

我正在学习 TensorFlow,构建一个多层感知器模型。我正在研究一些示例,例如:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/multilayer_perceptron.ipynb

然后我在下面的代码中有一些问题:

def multilayer_perceptron(x, weights, biases):
    :
    :

pred = multilayer_perceptron(x, weights, biases)
    :
    :

with tf.Session() as sess:
    sess.run(init)
         :
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print ("Accuracy:", accuracy.eval({x: X_test, y: y_test_onehot}))

我想知道tf.argmax(prod,1)tf.argmax(y,1) 到底是什么意思和返回(类型和值)? correct_prediction 是变量而不是实际值吗?

最后,我们如何从tf会话中得到y_test_prediction数组(输入数据为X_test时的预测结果)?非常感谢!

【问题讨论】:

    标签: tensorflow neural-network deep-learning


    【解决方案1】:
    tf.argmax(input, axis=None, name=None, dimension=None)
    

    返回张量轴上具有最大值的索引。

    input 是一个张量,axis 描述了输入张量的哪个轴要减少。对于向量,使用axis = 0。

    对于您的具体情况,让我们使用两个数组来演示这一点

    pred = np.array([[31, 23,  4, 24, 27, 34],
                    [18,  3, 25,  0,  6, 35],
                    [28, 14, 33, 22, 20,  8],
                    [13, 30, 21, 19,  7,  9],
                    [16,  1, 26, 32,  2, 29],
                    [17, 12,  5, 11, 10, 15]])
    
    y = np.array([[31, 23,  4, 24, 27, 34],
                    [18,  3, 25,  0,  6, 35],
                    [28, 14, 33, 22, 20,  8],
                    [13, 30, 21, 19,  7,  9],
                    [16,  1, 26, 32,  2, 29],
                    [17, 12,  5, 11, 10, 15]])
    

    评估tf.argmax(pred, 1) 给出一个张量,其评估将给出array([5, 5, 2, 1, 3, 0])

    评估tf.argmax(y, 1) 给出一个张量,其评估将给出array([5, 5, 2, 1, 3, 0])

    tf.equal(x, y, name=None) takes two tensors(x and y) as inputs and returns the truth value of (x == y) element-wise. 
    

    按照我们的示例,tf.equal(tf.argmax(pred, 1),tf.argmax(y, 1)) 返回一个张量,其评估将给出array(1,1,1,1,1,1)

    correct_prediction 是一个张量,其评估将给出 0 和 1 的一维数组

    y_test_prediction可以通过执行pred = tf.argmax(logits, 1)得到

    可以通过以下链接访问 tf.argmax 和 tf.equal 的文档。

    tf.argmax() https://www.tensorflow.org/api_docs/python/math_ops/sequence_comparison_and_indexing#argmax

    tf.equal() https://www.tensorflow.org/versions/master/api_docs/python/control_flow_ops/comparison_operators#equal

    【讨论】:

    • logits 是该代码中的pred 吗?我的意思是 logits 是模型的结果,例如out = tf.matmul(fully_conn_layer, _weights['out']) + _biases['out'](我有一个 CNN 实现 AlexNet 模型)?所以我不明白你是如何获得pred = tf.argmax(logits, 1)
    • 因为在这种情况下,logits 是“样本/图片”所属的概率(对于每个类别)。所以为了得到pred。对于“样本/图片”,您获得最大概率,因此您使用 tf.argmax
    【解决方案2】:

    阅读文档:

    tf.argmax

    返回张量轴上具有最大值的索引。

    tf.equal

    按元素返回 (x == y) 的真值。

    tf.cast

    将张量转换为新类型。

    tf.reduce_mean

    计算张量维度上元素的平均值。


    现在您可以轻松解释它的作用。您的 y 是 one-hot 编码的,因此它有一个 1 并且所有其他都是零。您的 pred 代表类的概率。因此 argmax 找到最佳预测和正确值的位置。然后检查它们是否相同。

    所以现在您的correct_prediction 是一个真/假值向量,其大小等于您要预测的实例数。您将其转换为浮点数并取平均值。


    实际上这部分在 Evaluate the Model 部分的TF tutorial 中有很好的解释

    【讨论】:

    • 在真实模型中这个过程,我的意思是 tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) 的过程是针对每个样本或每个批次完成的还是每个时代?最终值是多少?
    【解决方案3】:

    tf.argmax(输入,轴=无,名称=无,维度=无)

    返回张量轴上具有最大值的索引。

    对于特定情况,它接收 pred 作为参数,因为它是 input1 作为 axis。轴描述了输入张量的哪个轴要减少。对于向量,使用axis = 0。

    示例:给定列表[2.11,1.0021,3.99,4.32] argmax 将返回3,这是最大值的索引。


    correct_prediction 是稍后将评估的张量。它不是一个常规的 python 变量。它包含稍后计算该值的必要信息。 对于这种特定情况,它将是另一个张量 accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) 的一部分,并将由 evalaccuracy.eval({x: X_test, y: y_test_onehot}) 上进行评估。


    y_test_prediction 应该是您的 correct_prediction 张量。

    【讨论】:

    • 为什么会返回3?其中最大值是 4.32?
    • @AyodhyankitPaul 索引4.32,最高值是3,从0开始(2.11)
    • 知道了,它返回索引而不是值。
    【解决方案4】:

    对于没有太多时间了解tf.argmax的人:

    x = np.array([[1, 9, 3],[4, 5, 6]])

    tf.argmax(x, axis = 0)

    输出:

    [array([1, 0, 1], dtype=int64)]
    

    tf.argmax(x,axis = 1)

    输出:

    [array([1, 2], dtype=int64)]
    

    source

    【讨论】:

      猜你喜欢
      • 2018-10-29
      • 1970-01-01
      • 2019-11-18
      • 1970-01-01
      • 2018-06-29
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多