|
|
|
|
|
|
|
|
|
|
|
import math |
|
from collections import namedtuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from fairseq import options, utils |
|
from fairseq.modules import ( |
|
AdaptiveSoftmax, |
|
LayerNorm, |
|
MultiheadAttention, |
|
PositionalEmbedding, |
|
) |
|
|
|
|
|
EncoderOut = namedtuple( |
|
"TransformerEncoderOut", |
|
[ |
|
"encoder_out", |
|
"encoder_padding_mask", |
|
"encoder_embedding", |
|
"encoder_states", |
|
], |
|
) |
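
# Illustrative sketch (not part of the original module): EncoderOut mirrors the
# encoder output structure used elsewhere in fairseq; fields a caller does not
# need can simply be filled with None.
def _example_encoder_out():  # hypothetical helper, for illustration only
    x = torch.zeros(7, 2, 16)                            # T x B x C encoder states
    padding_mask = torch.zeros(2, 7, dtype=torch.bool)   # B x T, True at pad positions
    return EncoderOut(
        encoder_out=x,
        encoder_padding_mask=padding_mask,
        encoder_embedding=None,
        encoder_states=None,
    )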
|
|
|
|
|
class TransformerEncoderEmbedding(nn.Module): |
|
""" Encoder Embedding + Positional Embedding """ |
|
|
|
def __init__(self, args, embed_tokens): |
|
super().__init__() |
|
self.dropout = args.dropout |
|
self.max_source_positions = args.max_source_positions |
|
self.embed_tokens = embed_tokens |
|
if isinstance(embed_tokens, nn.ModuleList): |
|
self.padding_idx = embed_tokens[0].padding_idx |
|
embed_dim = sum(e.embedding_dim for e in embed_tokens) |
|
else: |
|
self.padding_idx = embed_tokens.padding_idx |
|
embed_dim = embed_tokens.embedding_dim |
|
self.embed_scale = math.sqrt(embed_dim) |
|
self.embed_positions = ( |
|
PositionalEmbedding( |
|
args.max_source_positions, |
|
embed_dim, |
|
self.padding_idx, |
|
learned=args.encoder_learned_pos, |
|
) |
|
if not args.no_token_positional_embeddings |
|
else None |
|
) |
|
if getattr(args, "layernorm_embedding", False): |
|
self.layernorm_embedding = LayerNorm(embed_dim) |
|
else: |
|
self.layernorm_embedding = None |
|
|
|
def forward(self, input): |
|
|
|
        # unpack the pipeline tuple; input[1] is not used by this module
        src_tokens = input[0]
        prev_output_tokens = input[2]
|
if isinstance(self.embed_tokens, nn.ModuleList): |
|
x_embed_list = [] |
|
for embed_tokens_part in self.embed_tokens: |
|
x_embed_list.append(embed_tokens_part(src_tokens)) |
|
|
|
embedded = torch.cat(x_embed_list, dim=-1) |
|
else: |
|
embedded = self.embed_tokens(src_tokens) |
|
x = embed = self.embed_scale * embedded |
|
if self.embed_positions is not None: |
|
x = embed + self.embed_positions(src_tokens) |
|
if self.layernorm_embedding: |
|
x = self.layernorm_embedding(x) |
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
|
|
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
|
|
|
|
|
encoder_padding_mask = src_tokens.eq(self.padding_idx) |
|
return (x, encoder_padding_mask, prev_output_tokens) |
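
# Illustrative usage sketch (not part of the original module). The argument names
# below mirror standard fairseq transformer options; treat the concrete values as
# assumptions for demonstration only.
def _example_encoder_embedding():  # hypothetical helper, for illustration only
    import argparse

    args = argparse.Namespace(
        dropout=0.1,
        max_source_positions=1024,
        encoder_learned_pos=False,
        no_token_positional_embeddings=False,
        layernorm_embedding=False,
    )
    embed_tokens = Embedding(num_embeddings=1000, embedding_dim=512, padding_idx=1)
    module = TransformerEncoderEmbedding(args, embed_tokens)

    src_tokens = torch.randint(2, 1000, (8, 20))            # B x T
    src_lengths = torch.full((8,), 20, dtype=torch.long)    # unused by this module
    prev_output_tokens = torch.randint(2, 1000, (8, 15))    # forwarded untouched
    x, padding_mask, prev = module((src_tokens, src_lengths, prev_output_tokens))
    return x.shape, padding_mask.shape                      # (20, 8, 512), (8, 20)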
|
|
|
|
|
class TransformerEncoderLayerNorm(nn.Module): |
|
""" |
|
    Layer norm applied at the end of all encoder layers if
    args.encoder_normalize_before is True.
|
""" |
|
|
|
def __init__(self, args, embed_dim): |
|
super().__init__() |
|
if args.encoder_normalize_before: |
|
self.layer_norm = LayerNorm(embed_dim) |
|
else: |
|
self.layer_norm = None |
|
|
|
def forward(self, input): |
|
x = input[0] |
|
encoder_padding_mask = input[1] |
|
prev_output_tokens = input[2] |
|
if self.layer_norm: |
|
x = self.layer_norm(x) |
|
|
|
return (x, encoder_padding_mask, prev_output_tokens) |
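
# Illustrative sketch (an assumption about how the surrounding pipeline-parallel
# model wires the encoder side; the real wiring lives in the model class, not in
# this file). Because every module consumes and returns the same 3-tuple, the
# encoder can be expressed as a plain nn.Sequential. args.encoder_layers is the
# standard fairseq option name.
def _example_encoder_stack(args, embed_tokens):  # hypothetical helper
    # usage: x, padding_mask, prev_output_tokens = stack((src_tokens, src_lengths, prev_output_tokens))
    return nn.Sequential(
        TransformerEncoderEmbedding(args, embed_tokens),
        *[TransformerEncoderLayer(args) for _ in range(args.encoder_layers)],
        TransformerEncoderLayerNorm(args, args.encoder_embed_dim),
    )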
|
|
|
|
|
class TransformerDecoderEmbedding(nn.Module): |
|
""" Decoder Embedding + Positional Embedding """ |
|
|
|
def __init__(self, args, embed_tokens): |
|
super().__init__() |
|
self.dropout = args.dropout |
|
self.share_input_output_embed = args.share_decoder_input_output_embed |
|
input_embed_dim = ( |
|
sum(e.embedding_dim for e in embed_tokens) |
|
if isinstance(embed_tokens, nn.ModuleList) |
|
else embed_tokens.embedding_dim |
|
) |
|
embed_dim = args.decoder_embed_dim |
|
self.output_embed_dim = args.decoder_output_dim |
|
|
|
padding_idx = ( |
|
embed_tokens[0].padding_idx |
|
if isinstance(embed_tokens, nn.ModuleList) |
|
else embed_tokens.padding_idx |
|
) |
|
self.max_target_positions = args.max_target_positions |
|
|
|
self.embed_tokens = embed_tokens |
|
self.embed_scale = math.sqrt(embed_dim) |
|
|
|
self.project_in_dim = ( |
|
Linear(input_embed_dim, embed_dim, bias=False) |
|
if embed_dim != input_embed_dim |
|
else None |
|
) |
|
|
|
self.embed_positions = ( |
|
PositionalEmbedding( |
|
args.max_target_positions, |
|
embed_dim, |
|
padding_idx, |
|
learned=args.decoder_learned_pos, |
|
) |
|
if not args.no_token_positional_embeddings |
|
else None |
|
) |
|
|
|
def forward(self, input): |
|
mt_task = False |
|
if isinstance(input, tuple): |
|
if len(input) == 3: |
|
encoder_out = input[0] |
|
encoder_padding_mask = input[1] |
|
prev_output_tokens = input[2] |
|
incremental_state = None |
|
mt_task = True |
|
            else:
                # tuple input without encoder output: only the previous output
                # tokens (input[0]) are used
                prev_output_tokens = input[0]
                encoder_out = None
|
encoder_padding_mask = None |
|
incremental_state = None |
|
|
|
        else:
            # a bare tensor of previous output tokens (no encoder context)
            prev_output_tokens = input
|
encoder_out = None |
|
encoder_padding_mask = None |
|
incremental_state = None |
|
|
|
positions = ( |
|
self.embed_positions( |
|
prev_output_tokens, |
|
incremental_state=incremental_state, |
|
) |
|
if self.embed_positions is not None |
|
else None |
|
) |
|
|
|
if incremental_state is not None: |
|
prev_output_tokens = prev_output_tokens[:, -1:] |
|
if positions is not None: |
|
positions = positions[:, -1:] |
|
|
|
|
|
|
|
if isinstance(self.embed_tokens, nn.ModuleList): |
|
x_embed_list = [] |
|
for embed_tokens_part in self.embed_tokens: |
|
x_embed_list.append(embed_tokens_part(prev_output_tokens)) |
|
|
|
x = self.embed_scale * torch.cat(x_embed_list, dim=-1) |
|
else: |
|
x = self.embed_scale * self.embed_tokens(prev_output_tokens) |
|
|
|
if self.project_in_dim is not None: |
|
x = self.project_in_dim(x) |
|
|
|
if positions is not None: |
|
x += positions |
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
|
|
|
|
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
|
if mt_task: |
|
return (x, encoder_out, encoder_padding_mask) |
|
return x |
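
# Illustrative sketch (not part of the original module): the decoder embedding
# accepts either the encoder's 3-tuple (translation) or a bare tensor of previous
# output tokens, and its return type changes accordingly.
def _example_decoder_embedding_inputs(decoder_embedding, encoder_tuple, prev_output_tokens):
    # hypothetical helper; encoder_tuple = (encoder_out, encoder_padding_mask, prev_output_tokens)
    x, encoder_out, encoder_padding_mask = decoder_embedding(encoder_tuple)  # translation case
    x_only = decoder_embedding(prev_output_tokens)                           # bare-tensor case
    return x, x_only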
|
|
|
|
|
class TransformerDecoderOutputLayer(nn.Module): |
|
def __init__(self, args, embed_tokens, dictionary): |
|
super().__init__() |
|
self.share_input_output_embed = args.share_decoder_input_output_embed |
|
self.embed_tokens = embed_tokens |
|
self.output_embed_dim = args.decoder_output_dim |
|
embed_dim = args.decoder_embed_dim |
|
|
|
self.project_out_dim = ( |
|
Linear(embed_dim, self.output_embed_dim, bias=False) |
|
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights |
|
else None |
|
) |
|
self.adaptive_softmax = None |
|
if args.adaptive_softmax_cutoff is not None: |
|
assert not isinstance(embed_tokens, nn.ModuleList) |
|
self.adaptive_softmax = AdaptiveSoftmax( |
|
len(dictionary), |
|
self.output_embed_dim, |
|
options.eval_str_list(args.adaptive_softmax_cutoff, type=int), |
|
dropout=args.adaptive_softmax_dropout, |
|
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, |
|
factor=args.adaptive_softmax_factor, |
|
tie_proj=args.tie_adaptive_proj, |
|
) |
|
        elif not self.share_input_output_embed:
            # learn a separate output projection matrix (stored in
            # self.embed_tokens in place of the input embedding module)
            self.embed_tokens = nn.Parameter(
|
torch.Tensor(len(dictionary), self.output_embed_dim) |
|
) |
|
nn.init.normal_( |
|
self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5 |
|
) |
|
|
|
if args.decoder_normalize_before and not getattr( |
|
args, "no_decoder_final_norm", False |
|
): |
|
self.layer_norm = LayerNorm(embed_dim) |
|
else: |
|
self.layer_norm = None |
|
|
|
def forward(self, input, apply_final_proj=True): |
|
if isinstance(input, tuple): |
|
x = input[0] |
|
else: |
|
x = input |
|
|
|
if self.layer_norm: |
|
x = self.layer_norm(x) |
|
|
|
|
|
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
|
|
|
if self.project_out_dim is not None: |
|
x = self.project_out_dim(x) |
|
if apply_final_proj: |
|
x = self.output_layer(x) |
|
return x |
|
|
|
def output_layer(self, features, **kwargs): |
|
"""Project features to the vocabulary size.""" |
|
if self.adaptive_softmax is None: |
|
|
|
if self.share_input_output_embed: |
|
if isinstance(self.embed_tokens, nn.ModuleList): |
|
output = None |
|
for i, emb in enumerate(self.embed_tokens): |
|
sidx = i * emb.embedding_dim |
|
eidx = (i + 1) * emb.embedding_dim |
|
if output is None: |
|
output = F.linear(features[:, :, sidx:eidx], emb.weight) |
|
else: |
|
output += F.linear(features[:, :, sidx:eidx], emb.weight) |
|
|
|
return output |
|
else: |
|
return F.linear(features, self.embed_tokens.weight) |
|
else: |
|
return F.linear(features, self.embed_tokens) |
|
else: |
|
return features |
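
# Illustrative sketch (an assumption about how the surrounding pipeline-parallel
# model wires the decoder side; the real wiring lives in the model class, not in
# this file). args.decoder_layers is the standard fairseq option name.
def _example_decoder_stack(args, embed_tokens, dictionary):  # hypothetical helper
    # usage (translation): logits = stack((encoder_out, encoder_padding_mask, prev_output_tokens))
    return nn.Sequential(
        TransformerDecoderEmbedding(args, embed_tokens),
        *[TransformerDecoderLayer(args) for _ in range(args.decoder_layers)],
        TransformerDecoderOutputLayer(args, embed_tokens, dictionary),
    )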
|
|
|
|
|
class TransformerEncoderLayer(nn.Module): |
|
"""Encoder layer block. |
|
In the original paper each operation (multi-head attention or FFN) is |
|
postprocessed with: `dropout -> add residual -> layernorm`. In the |
|
tensor2tensor code they suggest that learning is more robust when |
|
preprocessing each layer with layernorm and postprocessing with: |
|
`dropout -> add residual`. We default to the approach in the paper, but the |
|
tensor2tensor approach can be enabled by setting |
|
*args.encoder_normalize_before* to ``True``. |
|
|
|
Args: |
|
args (argparse.Namespace): parsed command-line arguments |
|
""" |
|
|
|
def __init__(self, args): |
|
super().__init__() |
|
self.embed_dim = args.encoder_embed_dim |
|
self.self_attn = MultiheadAttention( |
|
self.embed_dim, |
|
args.encoder_attention_heads, |
|
dropout=args.attention_dropout, |
|
self_attention=True, |
|
) |
|
self.self_attn_layer_norm = LayerNorm(self.embed_dim) |
|
self.dropout = args.dropout |
|
self.activation_fn = utils.get_activation_fn( |
|
activation=getattr(args, "activation_fn", "relu") |
|
) |
|
self.activation_dropout = getattr(args, "activation_dropout", 0) |
|
if self.activation_dropout == 0: |
|
|
|
self.activation_dropout = getattr(args, "relu_dropout", 0) |
|
self.normalize_before = args.encoder_normalize_before |
|
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim) |
|
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim) |
|
self.final_layer_norm = LayerNorm(self.embed_dim) |
|
|
|
def upgrade_state_dict_named(self, state_dict, name): |
|
""" |
|
Rename layer norm states from `...layer_norms.0.weight` to |
|
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to |
|
`...final_layer_norm.weight` |
|
""" |
|
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} |
|
for old, new in layer_norm_map.items(): |
|
for m in ("weight", "bias"): |
|
k = "{}.layer_norms.{}.{}".format(name, old, m) |
|
if k in state_dict: |
|
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] |
|
del state_dict[k] |
|
|
|
def forward(self, input): |
|
""" |
|
Args: |
|
input (Tuple): |
|
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` |
|
input[1] (ByteTensor/FloatTensor): encoder padding mask - |
|
binary ByteTensor of shape `(batch, src_len)` where padding elements |
|
are indicated by ``1``. |
|
input[2] (LongTensor): previous decoder outputs of shape |
|
                    `(batch, tgt_len)`, for teacher forcing
|
Returns: |
|
output (Tuple): |
|
                output[0] (Tensor): encoded output of shape `(src_len, batch, embed_dim)`
|
output[1] (ByteTensor/FloatTensor): encoder padding mask |
|
output[2] (LongTensor): previous decoder outputs |
|
""" |
|
x = input[0] |
|
encoder_padding_mask = input[1] |
|
prev_output_tokens = input[2] |
|
residual = x |
|
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) |
|
x, _ = self.self_attn( |
|
query=x, key=x, value=x, key_padding_mask=encoder_padding_mask |
|
) |
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
x = residual + x |
|
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) |
|
|
|
residual = x |
|
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) |
|
x = self.activation_fn(self.fc1(x)) |
|
x = F.dropout(x, p=self.activation_dropout, training=self.training) |
|
x = self.fc2(x) |
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
x = residual + x |
|
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) |
|
return (x, encoder_padding_mask, prev_output_tokens) |
|
|
|
def maybe_layer_norm(self, layer_norm, x, before=False, after=False): |
|
assert before ^ after |
|
if after ^ self.normalize_before: |
|
return layer_norm(x) |
|
else: |
|
return x |
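
# Illustrative sketch (not part of the original module): a minimal forward pass
# through a single encoder layer. Argument names mirror standard fairseq options;
# set encoder_normalize_before=True to switch to the pre-norm (tensor2tensor) variant.
def _example_encoder_layer():  # hypothetical helper, for illustration only
    import argparse

    args = argparse.Namespace(
        encoder_embed_dim=64,
        encoder_attention_heads=4,
        encoder_ffn_embed_dim=128,
        attention_dropout=0.0,
        dropout=0.0,
        activation_fn="relu",
        activation_dropout=0.0,
        relu_dropout=0.0,
        encoder_normalize_before=False,  # post-norm, as in the original paper
    )
    layer = TransformerEncoderLayer(args)

    x = torch.zeros(10, 2, 64)                           # T x B x C
    padding_mask = torch.zeros(2, 10, dtype=torch.bool)  # B x T
    prev_output_tokens = torch.zeros(2, 5, dtype=torch.long)
    y, mask, prev = layer((x, padding_mask, prev_output_tokens))
    return y.shape                                       # (10, 2, 64)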
|
|
|
|
|
class TransformerDecoderLayer(nn.Module): |
|
"""Decoder layer block. |
|
|
|
In the original paper each operation (multi-head attention, encoder |
|
attention or FFN) is postprocessed with: `dropout -> add residual -> |
|
layernorm`. In the tensor2tensor code they suggest that learning is more |
|
robust when preprocessing each layer with layernorm and postprocessing with: |
|
`dropout -> add residual`. We default to the approach in the paper, but the |
|
tensor2tensor approach can be enabled by setting |
|
*args.decoder_normalize_before* to ``True``. |
|
|
|
Args: |
|
args (argparse.Namespace): parsed command-line arguments |
|
no_encoder_attn (bool, optional): whether to attend to encoder outputs |
|
(default: False). |
|
""" |
|
|
|
def __init__( |
|
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False |
|
): |
|
super().__init__() |
|
self.embed_dim = args.decoder_embed_dim |
|
self.self_attn = MultiheadAttention( |
|
embed_dim=self.embed_dim, |
|
num_heads=args.decoder_attention_heads, |
|
dropout=args.attention_dropout, |
|
add_bias_kv=add_bias_kv, |
|
add_zero_attn=add_zero_attn, |
|
self_attention=True, |
|
) |
|
self.dropout = args.dropout |
|
self.activation_fn = utils.get_activation_fn( |
|
activation=getattr(args, "activation_fn", "relu") |
|
) |
|
self.activation_dropout = getattr(args, "activation_dropout", 0) |
|
if self.activation_dropout == 0: |
|
|
|
self.activation_dropout = getattr(args, "relu_dropout", 0) |
|
self.normalize_before = args.decoder_normalize_before |
|
|
|
|
|
|
|
|
|
export = getattr(args, "char_inputs", False) |
|
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) |
|
|
|
if no_encoder_attn: |
|
self.encoder_attn = None |
|
self.encoder_attn_layer_norm = None |
|
else: |
|
self.encoder_attn = MultiheadAttention( |
|
self.embed_dim, |
|
args.decoder_attention_heads, |
|
kdim=getattr(args, "encoder_embed_dim", None), |
|
vdim=getattr(args, "encoder_embed_dim", None), |
|
dropout=args.attention_dropout, |
|
encoder_decoder_attention=True, |
|
) |
|
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) |
|
|
|
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) |
|
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) |
|
|
|
self.final_layer_norm = LayerNorm(self.embed_dim, export=export) |
|
self.need_attn = True |
|
|
|
self.onnx_trace = False |
|
|
|
def prepare_for_onnx_export_(self): |
|
self.onnx_trace = True |
|
|
|
def forward(self, input): |
|
""" |
|
Args: |
|
input (Tuple): |
|
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` |
|
                input[1] (Tensor): encoder output of shape `(src_len, batch, embed_dim)`
|
input[2] (ByteTensor/FloatTensor): encoder padding mask - |
|
binary ByteTensor of shape `(batch, src_len)` where padding elements |
|
are indicated by ``1``. |
|
Returns: |
|
output (Tuple): |
|
                output[0] (Tensor): decoded output of shape `(seq_len, batch, embed_dim)`
                output[1] (Tensor): encoder output, passed through unchanged
                output[2] (ByteTensor/FloatTensor): encoder padding mask
|
""" |
|
|
|
mt_task = False |
|
if isinstance(input, tuple): |
|
x = input[0] |
|
encoder_out = input[1] |
|
encoder_padding_mask = input[2] |
|
incremental_state = None |
|
mt_task = True |
|
else: |
|
x = input |
|
encoder_out = None |
|
encoder_padding_mask = None |
|
incremental_state = None |
|
|
|
if incremental_state is None: |
|
self_attn_mask = self.buffered_future_mask(x) |
|
else: |
|
self_attn_mask = None |
|
|
|
|
|
|
|
        # saved attention state and the self-attention padding mask are not
        # threaded through the pipeline, so they are hard-coded to None here
        prev_self_attn_state = None
        prev_attn_state = None
        self_attn_padding_mask = None
|
|
|
residual = x |
|
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) |
|
if prev_self_attn_state is not None: |
|
if incremental_state is None: |
|
incremental_state = {} |
|
prev_key, prev_value = prev_self_attn_state |
|
saved_state = {"prev_key": prev_key, "prev_value": prev_value} |
|
self.self_attn._set_input_buffer(incremental_state, saved_state) |
|
x, attn = self.self_attn( |
|
query=x, |
|
key=x, |
|
value=x, |
|
key_padding_mask=self_attn_padding_mask, |
|
incremental_state=incremental_state, |
|
need_weights=False, |
|
attn_mask=self_attn_mask, |
|
) |
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
x = residual + x |
|
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) |
|
|
|
if self.encoder_attn is not None: |
|
residual = x |
|
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) |
|
if prev_attn_state is not None: |
|
if incremental_state is None: |
|
incremental_state = {} |
|
prev_key, prev_value = prev_attn_state |
|
saved_state = {"prev_key": prev_key, "prev_value": prev_value} |
|
self.encoder_attn._set_input_buffer(incremental_state, saved_state) |
|
x, attn = self.encoder_attn( |
|
query=x, |
|
key=encoder_out, |
|
value=encoder_out, |
|
key_padding_mask=encoder_padding_mask, |
|
incremental_state=incremental_state, |
|
static_kv=True, |
|
need_weights=(not self.training and self.need_attn), |
|
) |
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
x = residual + x |
|
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) |
|
|
|
residual = x |
|
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) |
|
x = self.activation_fn(self.fc1(x)) |
|
x = F.dropout(x, p=self.activation_dropout, training=self.training) |
|
x = self.fc2(x) |
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
x = residual + x |
|
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) |
|
|
|
if mt_task: |
|
return (x, encoder_out, encoder_padding_mask) |
|
return x |
|
|
|
def buffered_future_mask(self, tensor): |
|
dim = tensor.size(0) |
|
if ( |
|
not hasattr(self, "_future_mask") |
|
or self._future_mask is None |
|
or self._future_mask.device != tensor.device |
|
): |
|
self._future_mask = torch.triu( |
|
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 |
|
) |
|
if self._future_mask.size(0) < dim: |
|
self._future_mask = torch.triu( |
|
utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 |
|
) |
|
return self._future_mask[:dim, :dim] |
|
|
|
def maybe_layer_norm(self, layer_norm, x, before=False, after=False): |
|
assert before ^ after |
|
if after ^ self.normalize_before: |
|
return layer_norm(x) |
|
else: |
|
return x |
|
|
|
def make_generation_fast_(self, need_attn=False, **kwargs): |
|
self.need_attn = need_attn |
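
# Illustrative sketch (not part of the original module): buffered_future_mask builds
# an upper-triangular matrix with -inf strictly above the diagonal, so position i
# may only attend to positions j <= i.
def _example_future_mask(dim=4):  # hypothetical helper, for illustration only
    mask = torch.triu(utils.fill_with_neg_inf(torch.zeros(dim, dim)), 1)
    # mask[i, j] == 0 for j <= i and -inf for j > i
    return mask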
|
|
|
|
|
def Embedding(num_embeddings, embedding_dim, padding_idx): |
|
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) |
|
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) |
|
nn.init.constant_(m.weight[padding_idx], 0) |
|
return m |
|
|
|
|
|
def Linear(in_features, out_features, bias=True): |
|
m = nn.Linear(in_features, out_features, bias) |
|
nn.init.xavier_uniform_(m.weight) |
|
if bias: |
|
nn.init.constant_(m.bias, 0.0) |
|
return m |
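
# Illustrative sketch (not part of the original module): the two helpers above only
# differ from their torch.nn counterparts in initialization.
def _example_init_helpers():  # hypothetical helper, for illustration only
    emb = Embedding(num_embeddings=10, embedding_dim=4, padding_idx=1)
    lin = Linear(4, 8)
    # the padding row of the embedding and the linear bias are both initialized to zero
    return emb.weight[1].abs().sum().item(), lin.bias.abs().sum().item()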
|
|