【问题标题】:tf.newaxis operation in TensorFlowTensorFlow 中的 tf.newaxis 操作
【发布时间】:2021-01-26 10:19:48
【问题描述】:

x_train = x_train[..., tf.newaxis].astype("float32")

x_test = x_test[..., tf.newaxis].astype("float32")

有人能解释一下tf.newaxis 的工作原理吗?

我在文档中找到了一个简短的提及

https://www.tensorflow.org/api_docs/python/tf/strided_slice

但我无法正确理解。

【问题讨论】:

  • 我正在使用 tensorflow 2
  • 你明白我评论的逻辑了吗?
  • 对不起,我是新来的堆栈溢出,但完成了

标签: python tensorflow


【解决方案1】:

检查这个例子:

a = tf.constant([100])
print(a.shape) ## (1)
expanded_1 = tf.expand_dims(a,axis=1)
print(expanded_1.shape) ## (1,1)
expanded_2 = a[:, tf.newaxis]
print(expanded_2.shape) ## (1,1)

类似于expand_dims(),增加了一个新的轴。

如果要在张量的开头添加新轴,请使用

expanded_2 = a[tf.newaxis, :]

否则(最后)

expanded_2 = a[:,tf.newaxis]

【讨论】:

    【解决方案2】:

    您还可以使用tf.newaxis 为张量添加维度,同时保持相同的信息。

    # Create a rank 2 tensor (2 dimensions)
    rank_2_tensor = tf.constant([[10, 7],
                                 [3, 4]])
    
    print("dimension: ", rank_2_tensor.ndim)
    print("shape    : ", rank_2_tensor.shape)
    

    输出:

    尺寸:2
    形状:TensorShape([2, 2])

    # Add an extra dimension (to the end)
    rank_3_tensor = rank_2_tensor[..., tf.newaxis] 
    # in Python "..." means "all dimensions prior to"
    
    print("dimension: ", rank_3_tensor .ndim)
    print("shape    : ", rank_3_tensor .shape)
    

    输出:

    尺寸:3
    形状:TensorShape([2, 2, 1])

    您可以使用 tf.expand_dims() 实现相同的目的。

    rank_new_3_tensor = tf.expand_dims(rank_2_tensor, axis=-1) # "-1" means last axis
    print("dimension: ", rank_new_3_tensor .ndim)
    print("shape    : ", rank_new_3_tensor .shape)
    

    输出:

    尺寸:3
    形状:TensorShape([2, 2, 1])

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2017-10-21
      • 2018-06-28
      • 1970-01-01
      • 2020-09-01
      • 2019-02-15
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多