【问题标题】:How to support masking in custom tf.keras.layers.Layer如何在自定义 tf.keras.layers.Layer 中支持遮罩
【发布时间】:2019-08-06 04:36:20
【问题描述】:

我正在实现一个需要支持屏蔽的自定义 tf.keras.layers.Layer

考虑以下场景

embedded = tf.keras.layer.Embedding(input_dim=vocab_size + 1, 
                                    output_dim=n_dims, 
                                    mask_zero=True)
x = MyCustomKerasLayers(embedded)

现在根据文档

mask_zero: 输入值 0 是否是一个特殊的“填充”值,应该被屏蔽掉。这在使用可能采用可变长度输入的循环层时很有用。 如果为 True,则模型中的所有后续层都需要支持屏蔽,否则将引发异常。如果 mask_zero 设置为 True,则索引 0 不能在词汇表中使用(input_dim 应该等于词汇表的大小 + 1)。

我想知道,这是什么意思?查看TensorFlow's custom layers guidetf.keras.layer.Layer 文档,不清楚应该做些什么来支持屏蔽

  1. 如何支持屏蔽?

  2. 如何从过去的图层访问蒙版?

  3. 假设输入 (batch, time, channels) 或 `(batch, time) 掩码看起来会有所不同吗?它们的形状是什么?

  4. 如何将其传递到下一层?

【问题讨论】:

    标签: python tensorflow keras


    【解决方案1】:
    1. 要支持屏蔽,应在自定义层内实现compute_mask 方法

    2. 要访问掩码,只需在call 方法中添加参数mask 作为第二个位置参数,即可访问(例如call(self, inputs, mask=None)

    3. 这个猜不出来,是layer的before负责计算mask

    4. 一旦你实现了compute_mask,将掩码传递到下一层会自动发生 - 不包括模型子类化的情况,在这种情况下,由你来计算掩码并传递它们。

    例子:

    class MyCustomKerasLayers(tf.keras.layers.Layer):
        def __init__(self, .......):
            ...
    
        def compute_mask(self, inputs, mask=None):
            # Just pass the received mask from previous layer, to the next layer or 
            # manipulate it if this layer changes the shape of the input
            return mask
    
        def call(self, input, mask=None):
            # using 'mask' you can access the mask passed from the previous layer
    

    注意这个例子只是传递了掩码,如果图层输出的形状与接收到的不同,您应该在compute_mask中相应地更改掩码以传递正确的形状

    编辑

    现在解释也包含在tf.keras masking and padding documentation中。

    【讨论】:

    • 但如果输出形状与输入形状不同怎么办?
    • 所以你将不得不更新掩码,你应该在compute_mask 中操作它并返回具有新形状的新掩码
    • 我不知道新掩码在 compute_mask 中应该是什么样子。例如,如果这是一个 rnn 问题,假设我们将所有序列填充到 sen_len。假设输入具有形状 (batch_size, seq_len, feature_dim) 和输出形状 (batch_size, seq_len, feature_dim/2)。 compue_mask 前后的掩码是什么样的。我可以在评论中问这样的问题吗?还是我应该打开一个新问题?
    • 掩码比输入低一级,所以如果您有(batch_size, seq_len, channels),掩码通常是(batch_size, seq_len) - 所以在您的具体示例中,您不需要更改掩码因为保留了相同的无关信息索引(沿第二个轴)
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2023-03-08
    • 1970-01-01
    • 2015-03-23
    • 2011-11-24
    • 2016-12-27
    相关资源
    最近更新 更多