【问题标题】:Tracking tensor shape at graph creation time在图创建时跟踪张量形状
【发布时间】:2018-03-02 08:44:33
【问题描述】:

在某些情况下,张量流似乎能够在图创建时检查张量的值,而在其他情况下则失败。

>>> shape = [constant([2])[0], 3]
>>> reshape([1,2,3,4,5,6], shape)
<tf.Tensor 'Reshape_13:0' shape=(2, 3) dtype=int32>
>>> zeros(shape)
<tf.Tensor 'zeros_2:0' shape=(?, 3) dtype=float32>

在上面的例子中,reshape() 可以看到作为 shape 传入的张量的值为 2,结果输出的 shape 为 (2,3) 但 zeros() 不能,静态 shape 为 ( ?,3)。造成差异的原因是什么?

我的同事发布了Determining tensor shapes at time of graph creation in TensorFlow,它基于相同的潜在问题,但他提出了一个稍微不同的问题,即如何最好地使用 tensorflow 来解决这类问题,而我的问题是关于为什么 tensorflow 会这样方法。是bug吗?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    TD;DR:

    • tf.reshape 可以推断输出的形状,但tf.zeros 不能;
    • shape 支持整数(如静态/确定)和张量(如动态/不定)这两个函数。

    代码更具体、更清晰:

    shape = [tf.constant([2])[0], tf.constant([3])[0]]
    print(tf.reshape([1,2,3,4,5,6], shape))  
    # Tensor("Reshape:0", shape=(?, ?), dtype=int32)
    print(tf.zeros(shape))  
    # Tensor("zeros:0", shape=(?, ?), dtype=float32)
    

    还有这个:

    shape = [tf.constant([5])[0], 3]
    print tf.reshape([1,2,3,4,5,6], shape)  
    # Tensor("Reshape:0", shape=(2, 3), dtype=int32)
    # This will cause an InvalidArgumentError at running time!
    

    当使用Tensor(如tf.constant([2])[0])作为shape 创建另一个Tensor(如tf.zeros(shape))时,图形在创建时总是不确定的。但是,tf.reshape() 不同。它可以使用输入的形状和给定​​的形状(静态部分)推断输出的形状。

    在您的代码中,3 是一个静态整数,输入的形状是给定的([6]); (2, 3) 的形状实际上是通过推断获得的,而不是提供的。这可以在代码的第二部分证明。虽然我给了tf.constant([5]),但形状并没有改变。 (图创建时没有错误,但在运行时引发了错误!)

    【讨论】:

    • 是的,我得出了同样的结论,reshape 是使用第一个参数的形状来推断输出形状应该是什么。
    猜你喜欢
    • 2018-08-08
    • 2013-06-26
    • 1970-01-01
    • 1970-01-01
    • 2019-09-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-12-13
    相关资源
    最近更新 更多