【问题标题】:Tensorflow, update the Variable to have arbitrary shapeTensorflow,更新变量以具有任意形状
【发布时间】:2017-05-25 23:03:08
【问题描述】:

所以,根据documentation,我们可以使用 tf.assign 和 validate_shape=False 来改变形状。它确实改变了变量内容的形状,但是您可以从 get_shape() 获得的形状不会得到更新。例如:

>>> a = tf.Variable([1, 1, 1, 1])
>>> sess.run(tf.global_variables_initializer())
>>> tf.assign(a, [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]], validate_shape=False).eval()
array([[1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1]], dtype=int32)
>>> a.get_shape()
TensorShape([Dimension(4)])

很烦人的是,网络的后面层的形状基于这个变量的 get_shape() 值。因此,即使实际形状是正确的,Tensorflow 也会抱怨尺寸不匹配。那么关于如何更新每个变量的“相信”形状有什么想法吗?

【问题讨论】:

  • set_shape 可能吗?
  • set_shape 将尝试根据现有形状信息进行验证。我会在创建变量时设置validate_shape=False,在这种情况下,静态形状信息将完全未知(如果您愿意,可以使用set_shape 对其进行细化)。
  • 我试试看。谢谢!

标签: tensorflow


【解决方案1】:

简而言之:使用set_shape 来更新变量的静态形状。


你可以通过阅读TF FAQ来了解发生了什么:

在 TensorFlow 中,张量同时具有静态(推断)形状和 动态(真实)形状。静态形状可以使用 tf.Tensor.get_shape 方法:这个形状是从操作中推断出来的 用于创建张量的,可能是部分完整的。如果 静态形状未完全定义,张量 t 的动态形状 可以通过评估tf.shape(t)来确定。

所以静态形状没有正确推断出来,你应该给 TF 一个提示。幸运的是,同一个常见问题解答中的以下几行告诉您该怎么做:

tf.Tensor.set_shape 方法更新张量的静态形状 对象,它通常用于提供额外的形状 无法直接推断出的信息。它没有改变 张量的动态形状。

【讨论】:

    【解决方案2】:

    由于 validate_shape 设置为 false 变量的静态形状不会在图表中自动更新。一个解决方法是使用新形状(已知)手动设置它

    a = tf.Variable([1, 1, 1, 1], validate_shape=False)
    sess.run(tf.global_variables_initializer())
    new_arr_assign = np.array([[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]])
    tf.assign(a, new_arr_assign, validate_shape=False).eval(session=sess)
    a.set_shape(new_arr_assign.shape)
    a.get_shape()
    # results: TensorShape([Dimension(2), Dimension(7)])
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2016-02-12
      • 1970-01-01
      • 1970-01-01
      • 2018-09-19
      • 1970-01-01
      相关资源
      最近更新 更多