import torch
import torch.nn as nn

from fairseq.distributed import fsdp_wrap
from fairseq.models.transformer import (
    TransformerConfig,
    TransformerDecoder,
    TransformerDecoderBase,
)
from fairseq.modules import MultiheadAttention
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase
class UniLMMultiheadAttention(MultiheadAttention):
    """Multi-head attention whose key projection bias is zero-initialized."""

    def __init__(
        self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0,
        bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False,
        encoder_decoder_attention=False, q_noise=0, qn_block_size=8,
    ):
        super().__init__(
            embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout,
            bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn,
            self_attention=self_attention,
            encoder_decoder_attention=encoder_decoder_attention,
            q_noise=q_noise, qn_block_size=qn_block_size,
        )
        # Rebuild the key projection so that it carries a bias term and that
        # bias starts at zero, replacing whatever the parent class built.
        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=True), q_noise, qn_block_size
        )
        self.k_proj.bias = nn.Parameter(
            torch.zeros_like(self.k_proj.bias, requires_grad=False)
        )
class UniLMDecoderLayer(TransformerDecoderLayerBase):
    """Decoder layer that uses UniLMMultiheadAttention for self-attention."""

    def build_self_attention(
        self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
    ):
        return UniLMMultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            dropout=cfg.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not cfg.cross_self_attention,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )
class UniLMDecoderBase(TransformerDecoderBase):
    """Transformer decoder that builds UniLMDecoderLayer instead of the default layer."""

    def build_decoder_layer(self, cfg, no_encoder_attn=False):
        layer = UniLMDecoderLayer(cfg, no_encoder_attn)
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer
class UniLMDecoder(UniLMDecoderBase):
    """Namespace-based wrapper around UniLMDecoderBase: converts the legacy
    argparse-style `args` into a TransformerConfig before delegating."""

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        output_projection=None,
    ):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            no_encoder_attn=no_encoder_attn,
            output_projection=output_projection,
        )

    def build_output_projection(self, args, dictionary, embed_tokens):
        super().build_output_projection(
            TransformerConfig.from_namespace(args), dictionary, embed_tokens
        )

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return super().build_decoder_layer(
            TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
        )
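

# --- Hedged usage sketch (not part of the original module) ---
# A minimal sanity check, assuming fairseq and torch are installed: construct
# the customized attention directly and verify that the key-projection bias is
# zero-initialized. The dimensions below are illustrative only.
if __name__ == "__main__":
    attn = UniLMMultiheadAttention(embed_dim=512, num_heads=8, self_attention=True)
    assert attn.k_proj.bias.shape == (512,)
    assert torch.all(attn.k_proj.bias == 0)
    print("UniLMMultiheadAttention key-projection bias is zero-initialized.")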