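"""UniLM-style decoder components built on top of fairseq's transformer stack.

This module defines a MultiheadAttention variant whose key projection carries a
zero-initialised bias, a decoder layer that uses it for self-attention, a
decoder base that builds those layers (with optional activation checkpointing
and FSDP wrapping), and a thin wrapper that accepts legacy argparse-style args.
"""
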
import torch
import torch.nn as nn

from fairseq.distributed import fsdp_wrap
from fairseq.models.transformer import (
    TransformerConfig,
    TransformerDecoder,
    TransformerDecoderBase,
)
from fairseq.modules import MultiheadAttention
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase


class UniLMMultiheadAttention(MultiheadAttention):
    """MultiheadAttention whose key projection carries a zero-initialised bias (UniLM-style)."""

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, q_noise=0, qn_block_size=8):
        super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention, encoder_decoder_attention=encoder_decoder_attention, q_noise=q_noise, qn_block_size=qn_block_size)
        # Rebuild the key projection with a bias term and zero-initialise it; nn.Parameter
        # keeps the bias trainable even though zeros_like is created with requires_grad=False.
        self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=True), q_noise, qn_block_size)
        self.k_proj.bias = nn.Parameter(torch.zeros_like(self.k_proj.bias, requires_grad=False))


class UniLMDecoderLayer(TransformerDecoderLayerBase):
    def build_self_attention(
        self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
    ):
        # Use the UniLM attention variant (zero-initialised key bias) for decoder
        # self-attention; everything else follows the base transformer layer.
        return UniLMMultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            dropout=cfg.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not cfg.cross_self_attention,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )


class UniLMDecoderBase(TransformerDecoderBase):
    def build_decoder_layer(self, cfg, no_encoder_attn=False):
        layer = UniLMDecoderLayer(cfg, no_encoder_attn)
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer


class UniLMDecoder(UniLMDecoderBase):
    """Wrapper that accepts legacy argparse-style ``args`` and converts them to a TransformerConfig."""

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        output_projection=None,
    ):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            no_encoder_attn=no_encoder_attn,
            output_projection=output_projection,
        )

    def build_output_projection(self, args, dictionary, embed_tokens):
        super().build_output_projection(
            TransformerConfig.from_namespace(args), dictionary, embed_tokens
        )

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return super().build_decoder_layer(
            TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
        )
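

# ---------------------------------------------------------------------------
# Illustrative smoke test, not part of the original module. It assumes only
# that torch and a compatible fairseq version are installed; tensors follow
# fairseq's (time, batch, channel) layout. A sketch, not a definitive test.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # The UniLM attention variant behaves like fairseq's MultiheadAttention,
    # except that its key projection starts with an all-zero bias.
    attn = UniLMMultiheadAttention(embed_dim=16, num_heads=4, self_attention=True)
    assert torch.all(attn.k_proj.bias == 0)

    x = torch.randn(5, 2, 16)  # (tgt_len, batch, embed_dim)
    out, _ = attn(query=x, key=x, value=x)
    assert out.shape == (5, 2, 16)

    # A decoder layer built from a default TransformerConfig picks up the
    # UniLM attention automatically through build_self_attention.
    layer = UniLMDecoderLayer(TransformerConfig())
    assert isinstance(layer.self_attn, UniLMMultiheadAttention)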