【问题标题】:Reorder axis in TensorFlow Keras layerTensorFlow Keras 层中的重新排序轴
【发布时间】:2021-12-18 16:36:34
【问题描述】:

我正在构建一个模型,该模型沿第一个非批处理轴对数据应用随机洗牌,应用一系列 Conv1D,然后应用洗牌的逆。不幸的是,tf.gather 层弄乱了批次维度None,我不知道为什么。

下面是一个例子。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

dim = 90
input_img = keras.Input(shape=(dim, 4))

# Get random shuffle order
order = layers.Lambda(lambda x: tf.random.shuffle(tf.range(x)))(dim)

# Apply shuffle
tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))(input_img, order)

model = keras.models.Model(
   inputs=[input_img],
   outputs=tensor,
)

这里总结如下:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)           [(None, 90, 4)]           0         
_________________________________________________________________
lambda_51 (Lambda)           (90, 90, 4)               0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

而我希望lambda_51 的输出形状为(None, 90, 4)

【问题讨论】:

    标签: python tensorflow keras tensorflow2.0 tf.keras


    【解决方案1】:

    当您将input_imgorder 传递给tensor 层时,尝试将它们包装到一个列表中。

    这样tensor层就变成了:

    tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))([input_img, order])
    

    还有你的总结:

    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_2 (InputLayer)         [(None, 90, 4)]           0         
    _________________________________________________________________
    lambda_3 (Lambda)            (None, 90, 4)             0         
    =================================================================
    Total params: 0
    Trainable params: 0
    Non-trainable params: 0
    

    【讨论】:

      猜你喜欢
      • 2018-08-24
      • 2017-08-02
      • 1970-01-01
      • 1970-01-01
      • 2017-12-12
      • 2021-09-04
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多