【Question title】: Want to predict model output with some dummy inputs
【Posted】: 2022-06-10 18:00:15
【Question】:

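I am running a Keras model, and I want to generate the model's output for some dummy inputs before compiling and training it. Below are the parts of the model code that I think matter for understanding it. The full code is in this colab file, and you can also look at the official keras code here.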

class ShiftViTModel(keras.Model):
    """The ShiftViT Model.

    Args:
        data_augmentation (keras.Model): A data augmentation model.
        projected_dim (int): The dimension to which the patches of the image are
            projected.
        patch_size (int): The patch size of the images.
        num_shift_blocks_per_stages (list[int]): A list with the number of shift
            blocks in each stage.
        epsilon (float): The epsilon constant.
        mlp_dropout_rate (float): The dropout rate used in the MLP block.
        stochastic_depth_rate (float): The maximum drop rate probability.
        num_div (int): The number of divisions of the channels of the feature
            map. Defaults to 12.
        shift_pixel (int): The number of pixels to shift. Defaults to 1.
        mlp_expand_ratio (int): The ratio by which the initial MLP dense layer
            is expanded. Defaults to 2.
    """

    def __init__(
        self,
        data_augmentation,
        projected_dim,
        patch_size,
        num_shift_blocks_per_stages,
        epsilon,
        mlp_dropout_rate,
        stochastic_depth_rate,
        num_div=12,
        shift_pixel=1,
        mlp_expand_ratio=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.data_augmentation = data_augmentation
        self.patch_projection = layers.Conv2D(
            filters=projected_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="same",
        )
        self.stages = list()
        for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):
            if index == len(num_shift_blocks_per_stages) - 1:
                # This is the last stage, do not use the patch merge here.
                is_merge = False
            else:
                is_merge = True
            # Build the stages.
            self.stages.append(
                StackedShiftBlocks(
                    epsilon=epsilon,
                    mlp_dropout_rate=mlp_dropout_rate,
                    num_shift_blocks=num_shift_blocks,
                    stochastic_depth_rate=stochastic_depth_rate,
                    is_merge=is_merge,
                    num_div=num_div,
                    shift_pixel=shift_pixel,
                    mlp_expand_ratio=mlp_expand_ratio,
                )
            )
        self.global_avg_pool = layers.GlobalAveragePooling2D()

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "data_augmentation": self.data_augmentation,
                "patch_projection": self.patch_projection,
                "stages": self.stages,
                "global_avg_pool": self.global_avg_pool,
            }
        )
        return config

    def _calculate_loss(self, data, training=False):
        (images, labels) = data

        # Augment the images.
        augmented_images = self.data_augmentation(images, training=training)

        # Create patches and project the patches.
        projected_patches = self.patch_projection(augmented_images)

        # Pass through the stages.
        x = projected_patches
        for stage in self.stages:
            x = stage(x, training=training)

        # Get the logits.
        logits = self.global_avg_pool(x)

        # Calculate the loss and return it.
        total_loss = self.compiled_loss(labels, logits)
        return total_loss, labels, logits

    def train_step(self, inputs):
        with tf.GradientTape() as tape:
            total_loss, labels, logits = self._calculate_loss(
                data=inputs, training=True
            )

        # Apply gradients.
        train_vars = [
            self.data_augmentation.trainable_variables,
            self.patch_projection.trainable_variables,
            self.global_avg_pool.trainable_variables,
        ]
        train_vars = train_vars + [stage.trainable_variables for stage in self.stages]

        # Optimize the gradients.
        grads = tape.gradient(total_loss, train_vars)
        trainable_variable_list = []
        for (grad, var) in zip(grads, train_vars):
            for g, v in zip(grad, var):
                trainable_variable_list.append((g, v))
        self.optimizer.apply_gradients(trainable_variable_list)

        # Update the metrics.
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        _, labels, logits = self._calculate_loss(data=data, training=False)

        # Update the metrics.
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}
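The train_step above gathers `trainable_variables` from each sub-layer by hand, takes gradients of that nested list, and then flattens the `(grad, var)` pairs before calling `apply_gradients`. Since Keras tracks layers assigned as attributes (including layers held in a list such as `self.stages`), `self.trainable_variables` should already cover the same variables, so a flatter version is possible. A minimal sketch on a hypothetical toy model (`TinyTrainStepModel` and the module-level `loss_fn` are illustrative assumptions, and the loss is computed directly rather than through `compiled_loss`):

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical loss for the toy model (not taken from the question).
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)


class TinyTrainStepModel(keras.Model):
    """Hypothetical toy model with a custom train_step (not ShiftViT)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hidden = layers.Dense(8, activation="relu")
        self.head = layers.Dense(3)

    def call(self, inputs, training=False):
        return self.head(self.hidden(inputs))

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            logits = self(x, training=True)  # uses call() for the forward pass
            loss = loss_fn(y, logits)
        # self.trainable_variables already collects the variables of every
        # tracked sub-layer, so no manual per-layer list is needed.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}


x = tf.random.normal((8, 4))
y = tf.constant([0, 1, 2, 0, 1, 2, 0, 1])

model = TinyTrainStepModel()
model.compile(optimizer="adam")
history = model.fit(x, y, batch_size=4, epochs=1, verbose=0)
final_loss = history.history["loss"][-1]
print(f"loss after one epoch: {final_loss:.4f}")
```

Note that this variant relies on the model having a `call()` method, because the forward pass inside the tape goes through `self(x, training=True)`.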

Second block:

model = ShiftViTModel(
    data_augmentation=get_augmentation_model(),
    projected_dim=config.projected_dim,
    patch_size=config.patch_size,
    num_shift_blocks_per_stages=config.num_shift_blocks_per_stages,
    epsilon=config.epsilon,
    mlp_dropout_rate=config.mlp_dropout_rate,
    stochastic_depth_rate=config.stochastic_depth_rate,
    num_div=config.num_div,
    shift_pixel=config.shift_pixel,
    mlp_expand_ratio=config.mlp_expand_ratio,
)

What I am actually trying to do is get an output from the model above like this:

dummy_inputs = tf.ones((2, 32, 32, 3))
outputs = model(dummy_inputs, training=False)
print(outputs.shape)

but it raises an error:

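Unimplemented tf.keras.Model.call(): if you intend to create a Model with the Functional API, please provide inputs and outputs arguments. Otherwise, subclass Model with an overridden call() method.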
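The error comes from how subclassed Keras models work: calling `model(x)` dispatches to `call()`, and the base `keras.Model.call()` just raises `NotImplementedError`. Overriding only `train_step` and `test_step` can be enough for `fit()` and `evaluate()` (as long as they never call `self(...)`), but not for a direct forward pass. A minimal sketch of this behaviour, using hypothetical toy models (`NoCallModel`, `WithCallModel`, and their `Dense` layer are illustrative assumptions, not part of ShiftViT):

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class NoCallModel(keras.Model):
    """Hypothetical toy model: defines a layer but never overrides call()."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = layers.Dense(4)


class WithCallModel(NoCallModel):
    """Same layer, plus a call() that defines the forward pass."""

    def call(self, inputs, training=False):
        return self.dense(inputs)


dummy_inputs = tf.ones((2, 8))

try:
    NoCallModel()(dummy_inputs)
    call_failed = False
except NotImplementedError:
    # keras.Model.call() is a stub that raises when it is not overridden.
    call_failed = True

outputs = WithCallModel()(dummy_inputs)
print(call_failed, outputs.shape)
```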

【Question comments】:

    Tags: python tensorflow machine-learning keras deep-learning


    【Answer 1】:

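    The call method is not implemented, and in such implementation it is required if we want to check the model with dummy data. You can implement a call method in the ShiftViTModel class as follows, reusing the layers that are already used in the forward pass (see the train_step method, which runs them through _calculate_loss).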

    def call(self, images):
        augmented_images = self.data_augmentation(images)
        x = self.patch_projection(augmented_images)
        logits = self.global_avg_pool(x)
        return logits
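The call above reproduces only part of the forward pass in `_calculate_loss`: it skips `self.stages` and does not pass a `training` flag through. A fuller version would loop over the stages and forward `training` explicitly, exactly as `_calculate_loss` does, so augmentation and other training-only behaviour are not left to each layer's own default. The sketch below shows that pattern on a hypothetical stand-in model (`TinyStagedModel`; plain `Conv2D` layers stand in for `StackedShiftBlocks` and a `RandomFlip` layer stands in for the augmentation model). It illustrates the structure under those assumptions and is not the ShiftViT code itself:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class TinyStagedModel(keras.Model):
    """Hypothetical stand-in for ShiftViTModel: augmentation -> patch
    projection -> a Python list of stages -> global average pooling."""

    def __init__(self, projected_dim=8, patch_size=4, num_stages=3, **kwargs):
        super().__init__(**kwargs)
        self.data_augmentation = layers.RandomFlip("horizontal")
        self.patch_projection = layers.Conv2D(
            filters=projected_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="same",
        )
        # Plain Conv2D layers stand in for StackedShiftBlocks (assumption).
        self.stages = [
            layers.Conv2D(projected_dim, 3, padding="same")
            for _ in range(num_stages)
        ]
        self.global_avg_pool = layers.GlobalAveragePooling2D()

    def call(self, images, training=False):
        # Same order of operations as _calculate_loss, minus the loss.
        x = self.data_augmentation(images, training=training)
        x = self.patch_projection(x)
        # self.stages is a list, so apply each stage in turn.
        for stage in self.stages:
            x = stage(x, training=training)
        return self.global_avg_pool(x)


model = TinyStagedModel()
dummy_inputs = tf.ones((2, 32, 32, 3))
outputs = model(dummy_inputs, training=False)
print(outputs.shape)  # (2, 8)
```

With `training=False` the flip is skipped, so repeated calls on the same input give the same output, which makes this kind of dummy-input check reproducible.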
    

    Now, if we run

    model = ShiftViTModel( ... )
    x,y = next(iter(train_ds))
    print(x.shape, y.shape)
    model(x).shape
    
    (256, 32, 32, 3) (256, 1)
    TensorShape([256, 96])
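The printed TensorShape([256, 96]) follows from this call: without the stages, `global_avg_pool` averages the `patch_projection` output over its spatial axes, so the last dimension equals `projected_dim` (which therefore appears to be 96 in this config). Adding the stages can change that width, because every stage except the last is built with `is_merge=True`. The helper below is a hypothetical back-of-the-envelope check that assumes each patch merge doubles the channel count; that factor is an assumption about `StackedShiftBlocks`, whose code the question does not show:

```python
def expected_output_shape(batch_size, projected_dim, num_stages, merge_factor=2):
    """Hypothetical helper: shape after global average pooling.

    Assumes every stage except the last ends with a patch merge that
    multiplies the channel count by `merge_factor` (an assumption).
    """
    num_merges = max(num_stages - 1, 0)
    return (batch_size, projected_dim * merge_factor ** num_merges)


# Without the stages (as in the call above): only the projection sets the width.
print(expected_output_shape(256, 96, num_stages=0))  # (256, 96)
# With, e.g., four stages (three merges) under the doubling assumption.
print(expected_output_shape(256, 96, num_stages=4))  # (256, 768)
```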
    

    【Comments】:

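    • Thanks for your help. I want to save this model after training, but when I run model.save('/content/drive/MyDrive/VIT-SHIFT') it fails and says the model is not defined. I asked about this on Stack Overflow HERE but didn't get any useful answer.
    • One more thing I'd like to ask about is the list self.stages = list() created in __init__. It is the most important part of this model, so it also needs to be used in the call function, but you skipped it in the call function above.
    • I have already tried defining it as def call(self, images): augmented_images = self.data_augmentation(images) x = self.patch_projection(augmented_images) y = self.stages(x) logits = self.global_avg_pool(y) return logits, but it raises an error.
    • Please see the answer to your other, related question: stackoverflow.com/a/72496860/9215780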
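The attempt `y = self.stages(x)` fails because `self.stages` is a list-like container of layers, and a list is not callable; the stages have to be applied one after another, as the `for stage in self.stages` loop in `_calculate_loss` does. If a single callable is preferred, one option is to wrap the list in `keras.Sequential` inside `__init__` (for example `self.stages = keras.Sequential(stage_list)`). A minimal sketch of both options, with hypothetical `Dense` layers standing in for the `StackedShiftBlocks`:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical stand-ins for the StackedShiftBlocks built in __init__.
stage_list = [layers.Dense(16, activation="relu"), layers.Dense(8)]

x = tf.ones((2, 4))

# A Python list is not callable, which is why `self.stages(x)` fails.
try:
    stage_list(x)
    list_callable = True
except TypeError:
    list_callable = False

# Option 1: apply the stages in order, as _calculate_loss does.
y_loop = x
for stage in stage_list:
    y_loop = stage(y_loop)

# Option 2: wrap the same layers in keras.Sequential, which is callable.
stages = keras.Sequential(stage_list)
y_seq = stages(x, training=False)

print(list_callable, y_loop.shape, y_seq.shape)
```

Both options share the same layer objects and therefore the same weights, so they should produce the same result; the loop keeps the per-stage `training` handling visible, while `Sequential` gives a single object that can be called like the failing `self.stages(x)`.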