# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.modules import (
AdaptiveSoftmax,
LayerNorm,
MultiheadAttention,
PositionalEmbedding,
)
EncoderOut = namedtuple(
"TransformerEncoderOut",
[
"encoder_out", # T x B x C
"encoder_padding_mask", # B x T
"encoder_embedding", # B x T x C
"encoder_states", # List[T x B x C]
],
)
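# Illustrative only (not used below): how the fields of EncoderOut line up,
# with T = src_len, B = batch, C = embed_dim.
#
#   out = EncoderOut(
#       encoder_out=x,                  # T x B x C
#       encoder_padding_mask=pad_mask,  # B x T
#       encoder_embedding=token_embed,  # B x T x C
#       encoder_states=None,            # optional List[T x B x C]
#   )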
class TransformerEncoderEmbedding(nn.Module):
""" Encoder Embedding + Positional Embedding """
def __init__(self, args, embed_tokens):
super().__init__()
self.dropout = args.dropout
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
if isinstance(embed_tokens, nn.ModuleList):
self.padding_idx = embed_tokens[0].padding_idx
embed_dim = sum(e.embedding_dim for e in embed_tokens)
else:
self.padding_idx = embed_tokens.padding_idx
embed_dim = embed_tokens.embedding_dim
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = (
PositionalEmbedding(
args.max_source_positions,
embed_dim,
self.padding_idx,
learned=args.encoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
if getattr(args, "layernorm_embedding", False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
def forward(self, input):
# embed tokens and positions
src_tokens = input[0]
prev_output_tokens = input[2]
if isinstance(self.embed_tokens, nn.ModuleList):
x_embed_list = []
for embed_tokens_part in self.embed_tokens:
x_embed_list.append(embed_tokens_part(src_tokens))
embedded = torch.cat(x_embed_list, dim=-1)
else:
embedded = self.embed_tokens(src_tokens)
x = embed = self.embed_scale * embedded
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
if self.layernorm_embedding:
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
return (x, encoder_padding_mask, prev_output_tokens)
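# Minimal shape sketch (an illustration added here, not part of the model):
# pushing a dummy pipeline tuple through TransformerEncoderEmbedding. `args`
# is assumed to carry the fairseq fields used above (dropout,
# max_source_positions, no_token_positional_embeddings, ...).
def _encoder_embedding_shape_sketch(args, embed_tokens, src_tokens, prev_output_tokens):
    module = TransformerEncoderEmbedding(args, embed_tokens)
    # input[1] (src_lengths) is unused by forward, so None is fine here
    x, padding_mask, prev = module((src_tokens, None, prev_output_tokens))
    return x.shape, padding_mask.shape  # (T, B, C), (B, T)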
class TransformerEncoderLayerNorm(nn.Module):
"""
Layer norm at the the end of all encoder layers if
args.encoder_enormalize_before = True
"""
def __init__(self, args, embed_dim):
super().__init__()
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward(self, input):
x = input[0]
encoder_padding_mask = input[1]
prev_output_tokens = input[2]
if self.layer_norm:
x = self.layer_norm(x)
# keeping track of the incremental_state is not supported yet
return (x, encoder_padding_mask, prev_output_tokens)
class TransformerDecoderEmbedding(nn.Module):
""" Decoder Embedding + Positional Embedding """
def __init__(self, args, embed_tokens):
super().__init__()
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = (
sum(e.embedding_dim for e in embed_tokens)
if isinstance(embed_tokens, nn.ModuleList)
else embed_tokens.embedding_dim
)
embed_dim = args.decoder_embed_dim
self.output_embed_dim = args.decoder_output_dim
padding_idx = (
embed_tokens[0].padding_idx
if isinstance(embed_tokens, nn.ModuleList)
else embed_tokens.padding_idx
)
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
self.project_in_dim = (
Linear(input_embed_dim, embed_dim, bias=False)
if embed_dim != input_embed_dim
else None
)
self.embed_positions = (
PositionalEmbedding(
args.max_target_positions,
embed_dim,
padding_idx,
learned=args.decoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
def forward(self, input):
mt_task = False
if isinstance(input, tuple):
if len(input) == 3:
encoder_out = input[0]
encoder_padding_mask = input[1]
prev_output_tokens = input[2]
incremental_state = None # Hardcoding to avoid passing of None objects
mt_task = True
else:
# HACK for now, need to fix (TODO sidgoyal)
prev_output_tokens = input[0]
# discard "src_lengths"
encoder_out = None
encoder_padding_mask = None
incremental_state = None
else:
prev_output_tokens = input
encoder_out = None
encoder_padding_mask = None
incremental_state = None
positions = (
self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
if isinstance(self.embed_tokens, nn.ModuleList):
x_embed_list = []
for embed_tokens_part in self.embed_tokens:
x_embed_list.append(embed_tokens_part(prev_output_tokens))
x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
else:
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if mt_task:
return (x, encoder_out, encoder_padding_mask)
return x
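    # Illustrative call patterns for the branches above (assumptions about the
    # surrounding pipeline stages, not an official API):
    #
    #   x, enc_out, enc_mask = embed((encoder_out, encoder_padding_mask, prev_output_tokens))  # MT task
    #   x = embed(prev_output_tokens)  # LM-style, no encoder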
class TransformerDecoderOutputLayer(nn.Module):
def __init__(self, args, embed_tokens, dictionary):
super().__init__()
self.share_input_output_embed = args.share_decoder_input_output_embed
self.embed_tokens = embed_tokens
self.output_embed_dim = args.decoder_output_dim
embed_dim = args.decoder_embed_dim
self.project_out_dim = (
Linear(embed_dim, self.output_embed_dim, bias=False)
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
else None
)
self.adaptive_softmax = None
if args.adaptive_softmax_cutoff is not None:
assert not isinstance(embed_tokens, nn.ModuleList)
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif not self.share_input_output_embed:
self.embed_tokens = nn.Parameter(
torch.Tensor(len(dictionary), self.output_embed_dim)
)
nn.init.normal_(
self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5
)
if args.decoder_normalize_before and not getattr(
args, "no_decoder_final_norm", False
):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward(self, input, apply_final_proj=True):
if isinstance(input, tuple):
x = input[0]
else:
x = input
if self.layer_norm:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
if apply_final_proj:
x = self.output_layer(x)
return x
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
if isinstance(self.embed_tokens, nn.ModuleList):
output = None
for i, emb in enumerate(self.embed_tokens):
sidx = i * emb.embedding_dim
eidx = (i + 1) * emb.embedding_dim
if output is None:
output = F.linear(features[:, :, sidx:eidx], emb.weight)
else:
output += F.linear(features[:, :, sidx:eidx], emb.weight)
return output
else:
return F.linear(features, self.embed_tokens.weight)
else:
return F.linear(features, self.embed_tokens)
else:
return features
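    # Sketch of the tied-projection arithmetic above (illustrative): with an
    # nn.ModuleList whose embedding dims sum to C, each C_i-wide slice of the
    # features is projected by its own V x C_i embedding weight and the V-dim
    # logits are summed:
    #
    #   features: B x T x C  ->  F.linear(features[..., s:e], emb.weight): B x T x V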
class TransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.encoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, args):
super().__init__()
self.embed_dim = args.encoder_embed_dim
self.self_attn = MultiheadAttention(
self.embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, "activation_fn", "relu")
)
self.activation_dropout = getattr(args, "activation_dropout", 0)
if self.activation_dropout == 0:
# for backwards compatibility with models that use args.relu_dropout
self.activation_dropout = getattr(args, "relu_dropout", 0)
self.normalize_before = args.encoder_normalize_before
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
def upgrade_state_dict_named(self, state_dict, name):
"""
Rename layer norm states from `...layer_norms.0.weight` to
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layer_norms.{}.{}".format(name, old, m)
if k in state_dict:
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
del state_dict[k]
def forward(self, input):
"""
Args:
input (Tuple):
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
input[1] (ByteTensor/FloatTensor): encoder padding mask -
binary ByteTensor of shape `(batch, src_len)` where padding elements
are indicated by ``1``.
input[2] (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing)
Returns:
output (Tuple):
output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)`
output[1] (ByteTensor/FloatTensor): encoder padding mask
output[2] (LongTensor): previous decoder outputs
"""
x = input[0]
encoder_padding_mask = input[1]
prev_output_tokens = input[2]
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
x, _ = self.self_attn(
query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
return (x, encoder_padding_mask, prev_output_tokens)
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return layer_norm(x)
else:
return x
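    # Truth table for the XOR above (illustrative):
    #   normalize_before=True  -> layer norm runs on the `before=True` call
    #                             (pre-norm, the tensor2tensor variant)
    #   normalize_before=False -> layer norm runs on the `after=True` call
    #                             (post-norm, the paper's default)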
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.decoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.self_attn = MultiheadAttention(
embed_dim=self.embed_dim,
num_heads=args.decoder_attention_heads,
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=True,
)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, "activation_fn", "relu")
)
self.activation_dropout = getattr(args, "activation_dropout", 0)
if self.activation_dropout == 0:
# for backwards compatibility with models that use args.relu_dropout
self.activation_dropout = getattr(args, "relu_dropout", 0)
self.normalize_before = args.decoder_normalize_before
        # use LayerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
# TODO remove this once we update apex with the fix
export = getattr(args, "char_inputs", False)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = MultiheadAttention(
self.embed_dim,
args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
self.need_attn = True
self.onnx_trace = False
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def forward(self, input):
"""
Args:
input (Tuple):
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)`
input[2] (ByteTensor/FloatTensor): encoder padding mask -
binary ByteTensor of shape `(batch, src_len)` where padding elements
are indicated by ``1``.
Returns:
output (Tuple):
output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)`
output[1] (ByteTensor/FloatTensor): encoder padding mask
output[2] (LongTensor): previous decoder outputs
"""
# Note: incremental state is not yet supported
mt_task = False
if isinstance(input, tuple):
x = input[0]
encoder_out = input[1]
encoder_padding_mask = input[2]
incremental_state = None
mt_task = True
else:
x = input
encoder_out = None
encoder_padding_mask = None
incremental_state = None
if incremental_state is None:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
# TODO: add back prev_self_attn_state, prev_attn_state,
# self_attn_padding_mask
prev_self_attn_state = None
prev_attn_state = None
self_attn_padding_mask = None
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
if prev_self_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_self_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.self_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
if self.encoder_attn is not None:
residual = x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
if prev_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=(not self.training and self.need_attn),
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
if mt_task:
return (x, encoder_out, encoder_padding_mask)
return x
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, "_future_mask")
or self._future_mask is None
or self._future_mask.device != tensor.device
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(
utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
)
return self._future_mask[:dim, :dim]
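    # Example of the causal mask returned above for dim=3 (upper triangle -inf,
    # added to attention scores to block attention to future positions):
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]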
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return layer_norm(x)
else:
return x
def make_generation_fast_(self, need_attn=False, **kwargs):
self.need_attn = need_attn
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
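# e.g. Embedding(num_embeddings=1000, embedding_dim=512, padding_idx=1) gives a
# normally-initialized table (std = 512 ** -0.5) whose padding row is zeroed.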
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.0)
return m
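# Minimal smoke-test sketch (illustrative; the Namespace below is an assumption
# standing in for fairseq's parsed command-line args, not a real config).
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        encoder_embed_dim=16,
        encoder_attention_heads=4,
        encoder_ffn_embed_dim=32,
        attention_dropout=0.0,
        dropout=0.0,
        activation_fn="relu",
        activation_dropout=0.0,
        encoder_normalize_before=False,
    )
    layer = TransformerEncoderLayer(args)
    x = torch.zeros(5, 2, 16)                  # T x B x C
    pad = torch.zeros(2, 5, dtype=torch.bool)  # B x T
    out, mask, _ = layer((x, pad, None))
    print(out.shape)  # expected: torch.Size([5, 2, 16])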