【问题标题】:Dropping layers in Transformer models (PyTorch / HuggingFace)在 Transformer 模型中删除层 (PyTorch / HuggingFace)
【发布时间】:2021-12-18 11:07:01
【问题描述】:

我在 Transformer 模型中的层下降中遇到了这个有趣的paper,我实际上正在尝试实现它。但是,我想知道执行“层删除”的好做法是什么。

我有几个想法,但不知道去这里最干净/最安全的方式是什么:

  • 屏蔽不需要的层(某种修剪)
  • 将所需层复制到新模型中

如果有人之前已经这样做过或有建议,我会全力以赴!

干杯

【问题讨论】:

标签: python nlp pytorch huggingface-transformers


【解决方案1】:

我认为最安全的方法之一就是跳过前向传递中的给定层。

例如,假设您使用的是BERT,并且您在配置中添加了以下条目:

config.active_layers = [False, True] * 6  # using a 12 layers model

然后你可以像下面这样修改BertEncoder类:

class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            
            ########### MAGIC HERE #############
            if not self.config.active_layers[i]:
                continue
            
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

目前您可能需要使用新的Encoder 层编写您的特殊BERT 类。但是,您应该能够从huggingface 提供的预训练模型中加载权重。

BertEncoder 代码取自here

【讨论】:

  • 感谢 Luca 的建议,这对我来说看起来很安全。我明天试试,让你知道! :)
  • 嘿@Luca,我实际上得到了 HF 家伙的回答,他们的做法似乎与您的建议相似:) github.com/huggingface/transformers/blob/…
猜你喜欢
  • 2023-01-28
  • 2022-11-03
  • 2019-04-20
  • 2021-08-22
  • 2019-03-04
  • 1970-01-01
  • 2020-10-07
  • 2021-01-17
  • 2020-04-19
相关资源
最近更新 更多