"""
Implementation of "Attention is All You Need" and of
subsequent transformer-based architectures.
"""

import torch
import torch.nn as nn

from onmt.decoders.decoder import DecoderBase
from onmt.modules import MultiHeadedAttention, AverageAttention
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction
from onmt.utils.misc import sequence_mask


class TransformerDecoderLayerBase(nn.Module):
    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        """
        Args:
            d_model (int): the dimension of keys/values/queries in
                :class:`MultiHeadedAttention`, also the input size of
                the first layer of the :class:`PositionwiseFeedForward`.
            heads (int): the number of heads for MultiHeadedAttention.
            d_ff (int): the hidden size of the inner feed-forward layer in
                :class:`PositionwiseFeedForward`.
            dropout (float): dropout in residual, self-attn(dot) and
                feed-forward
            attention_dropout (float): dropout in context_attn (and
                self-attn(avg))
            self_attn_type (str): type of self-attention, "scaled-dot" or
                "average"
            max_relative_positions (int):
                Max distance between inputs in relative positions
                representations
            aan_useffn (bool): Turn on the FFN layer in the AAN decoder
            full_context_alignment (bool):
                whether to enable an extra full-context decoder forward for
                alignment
            alignment_heads (int):
                number of cross-attention heads to use for alignment guiding
            pos_ffn_activation_fn (ActivationFunction):
                activation function choice for PositionwiseFeedForward layer
        """
        super(TransformerDecoderLayerBase, self).__init__()

        if self_attn_type == "scaled-dot":
            self.self_attn = MultiHeadedAttention(
                heads,
                d_model,
                dropout=attention_dropout,
                max_relative_positions=max_relative_positions,
            )
        elif self_attn_type == "average":
            self.self_attn = AverageAttention(
                d_model, dropout=attention_dropout, aan_useffn=aan_useffn
            )

        self.feed_forward = PositionwiseFeedForward(
            d_model, d_ff, dropout, pos_ffn_activation_fn
        )
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        self.full_context_alignment = full_context_alignment
        self.alignment_heads = alignment_heads

    def forward(self, *args, **kwargs):
"""Extend `_forward` for (possibly) multiple decoder pass: |
|
Always a default (future masked) decoder forward pass, |
|
Possibly a second future aware decoder pass for joint learn |
|
full context alignement, :cite:`garg2019jointly`. |
|
|
|
Args: |
|
* All arguments of _forward. |
|
with_align (bool): whether return alignment attention. |
|
|
|
Returns: |
|
(FloatTensor, FloatTensor, FloatTensor or None): |
|
|
|
* output ``(batch_size, T, model_dim)`` |
|
* top_attn ``(batch_size, T, src_len)`` |
|
* attn_align ``(batch_size, T, src_len)`` or None |
|
""" |
|
with_align = kwargs.pop("with_align", False) |
|
output, attns = self._forward(*args, **kwargs) |
|
top_attn = attns[:, 0, :, :].contiguous() |
|
attn_align = None |
|
if with_align: |
|
if self.full_context_alignment: |
|
|
|
_, attns = self._forward(*args, **kwargs, future=True) |
|
|
|
if self.alignment_heads > 0: |
|
attns = attns[:, : self.alignment_heads, :, :].contiguous() |
|
|
|
|
|
|
|
|
|
attn_align = attns.mean(dim=1) |
|
return output, top_attn, attn_align |
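
    # Note on the alignment pass above (an illustrative sketch, not exercised
    # by this module; ``layer``, ``tgt``, ``memory_bank`` and the masks are
    # assumed names): with ``full_context_alignment=True`` and
    # ``with_align=True`` the layer runs ``_forward`` a second time with
    # ``future=True`` (i.e. without the future mask) and then averages the
    # first ``alignment_heads`` cross-attention heads, e.g. on a concrete
    # subclass such as ``TransformerDecoderLayer``:
    #
    #     out, top_attn, attn_align = layer(
    #         tgt, memory_bank, src_pad_mask, tgt_pad_mask, with_align=True
    #     )
    #     # attn_align: (batch_size, T, src_len) or None
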
    def update_dropout(self, dropout, attention_dropout):
        self.self_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.drop.p = dropout

    def _forward(self, *args, **kwargs):
        raise NotImplementedError

    def _compute_dec_mask(self, tgt_pad_mask, future):
        tgt_len = tgt_pad_mask.size(-1)
        if not future:
            # Combine the padding mask with an upper-triangular future mask
            # so each position can only attend to itself and the past.
            future_mask = torch.ones(
                [tgt_len, tgt_len],
                device=tgt_pad_mask.device,
                dtype=torch.uint8,
            )
            future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
            # BoolTensor was introduced in pytorch 1.2
            try:
                future_mask = future_mask.bool()
            except AttributeError:
                pass
            dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
        else:
            # Full-context pass: only padding positions are masked.
            dec_mask = tgt_pad_mask
        return dec_mask
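
    # Illustrative sketch of ``_compute_dec_mask`` (the tensors below are
    # assumptions, not taken from the library): for one sequence of length 3
    # whose last position is padding,
    #
    #     tgt_pad_mask = torch.tensor([[[False, False, True]]])
    #     dec_mask = layer._compute_dec_mask(tgt_pad_mask, future=False)
    #     # dec_mask[0] ==
    #     # [[False,  True,  True],
    #     #  [False, False,  True],
    #     #  [False, False,  True]]
    #
    # ``True`` marks positions that may *not* be attended to, so position i
    # sees only non-padding positions <= i; ``future=True`` instead returns
    # the padding mask unchanged.
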
    def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
        if isinstance(self.self_attn, MultiHeadedAttention):
            return self.self_attn(
                inputs_norm,
                inputs_norm,
                inputs_norm,
                mask=dec_mask,
                layer_cache=layer_cache,
                attn_type="self",
            )
        elif isinstance(self.self_attn, AverageAttention):
            return self.self_attn(
                inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
            )
        else:
            raise ValueError(
                f"self attention {type(self.self_attn)} not supported"
            )


class TransformerDecoderLayer(TransformerDecoderLayerBase):
"""Transformer Decoder layer block in Pre-Norm style. |
|
Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style, |
|
providing better converge speed and performance. This is also the actual |
|
implementation in tensor2tensor and also avalable in fairseq. |
|
See https://tunz.kr/post/4 and :cite:`DeeperTransformer`. |
|
|
|
.. mermaid:: |
|
|
|
graph LR |
|
%% "*SubLayer" can be self-attn, src-attn or feed forward block |
|
A(input) --> B[Norm] |
|
B --> C["*SubLayer"] |
|
C --> D[Drop] |
|
D --> E((+)) |
|
A --> E |
|
E --> F(out) |
|
|
|
""" |
|
|
|
def __init__( |
|
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        """
        Args:
            See TransformerDecoderLayerBase
        """
        super(TransformerDecoderLayer, self).__init__(
            d_model,
            heads,
            d_ff,
            dropout,
            attention_dropout,
            self_attn_type,
            max_relative_positions,
            aan_useffn,
            full_context_alignment,
            alignment_heads,
            pos_ffn_activation_fn=pos_ffn_activation_fn,
        )
        self.context_attn = MultiHeadedAttention(
            heads, d_model, dropout=attention_dropout
        )
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)

    def update_dropout(self, dropout, attention_dropout):
        super(TransformerDecoderLayer, self).update_dropout(
            dropout, attention_dropout
        )
        self.context_attn.update_dropout(attention_dropout)

    def _forward(
        self,
        inputs,
        memory_bank,
        src_pad_mask,
        tgt_pad_mask,
        layer_cache=None,
        step=None,
        future=False,
    ):
        """A naive forward pass for transformer decoder.

        # T: could be 1 in the case of stepwise decoding or tgt_len

        Args:
            inputs (FloatTensor): ``(batch_size, T, model_dim)``
            memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
            src_pad_mask (BoolTensor): ``(batch_size, 1, src_len)``
            tgt_pad_mask (BoolTensor): ``(batch_size, 1, T)``
            layer_cache (dict or None): cached layer info when stepwise decode
            step (int or None): stepwise decoding counter
            future (bool): If set True, do not apply future_mask.

        Returns:
            (FloatTensor, FloatTensor):

            * output ``(batch_size, T, model_dim)``
            * attns ``(batch_size, heads, T, src_len)``

        """
        dec_mask = None

        if inputs.size(1) > 1:
            # A mask is only needed when decoding more than one position at
            # a time; during stepwise decoding the cache handles the past.
            dec_mask = self._compute_dec_mask(tgt_pad_mask, future)

        inputs_norm = self.layer_norm_1(inputs)

        query, _ = self._forward_self_attn(
            inputs_norm, dec_mask, layer_cache, step
        )

        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)
        mid, attns = self.context_attn(
            memory_bank,
            memory_bank,
            query_norm,
            mask=src_pad_mask,
            layer_cache=layer_cache,
            attn_type="context",
        )
        output = self.feed_forward(self.drop(mid) + query)

        return output, attns
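
    # Usage sketch for this pre-norm layer (illustrative only; the sizes are
    # arbitrary assumptions and dropout is set to 0 for determinism):
    #
    #     layer = TransformerDecoderLayer(
    #         d_model=16, heads=4, d_ff=32, dropout=0.0, attention_dropout=0.0
    #     )
    #     inputs = torch.rand(2, 5, 16)        # (batch_size, T, model_dim)
    #     memory_bank = torch.rand(2, 7, 16)   # (batch_size, src_len, model_dim)
    #     src_pad_mask = torch.zeros(2, 1, 7, dtype=torch.bool)
    #     tgt_pad_mask = torch.zeros(2, 1, 5, dtype=torch.bool)
    #     output, attns = layer._forward(
    #         inputs, memory_bank, src_pad_mask, tgt_pad_mask
    #     )
    #     # output: (2, 5, 16), attns: (2, 4, 5, 7)

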
class TransformerDecoderBase(DecoderBase):
    def __init__(self, d_model, copy_attn, alignment_layer):
        super(TransformerDecoderBase, self).__init__()

        # Decoder state
        self.state = {}

        self._copy = copy_attn
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        self.alignment_layer = alignment_layer

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.copy_attn,
            opt.self_attn_type,
            opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
            opt.attention_dropout[0]
            if type(opt.attention_dropout) is list
            else opt.attention_dropout,
            embeddings,
            opt.max_relative_positions,
            opt.aan_useffn,
            opt.full_context_alignment,
            opt.alignment_layer,
            alignment_heads=opt.alignment_heads,
            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
        )
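
    # Sketch of the options ``from_opt`` reads (illustrative; an argparse-style
    # namespace with exactly these attributes is an assumption, and the
    # positional order above must line up with the subclass constructor;
    # ``SomeDecoderSubclass`` and ``embeddings`` are placeholders):
    #
    #     from argparse import Namespace
    #     opt = Namespace(
    #         dec_layers=6, dec_rnn_size=512, heads=8, transformer_ff=2048,
    #         copy_attn=False, self_attn_type="scaled-dot", dropout=[0.1],
    #         attention_dropout=[0.1], max_relative_positions=0,
    #         aan_useffn=False, full_context_alignment=False,
    #         alignment_layer=-3, alignment_heads=0,
    #         pos_ffn_activation_fn=ActivationFunction.relu,
    #     )
    #     decoder = SomeDecoderSubclass.from_opt(opt, embeddings)
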
    def init_state(self, src, memory_bank, enc_hidden):
        """Initialize decoder state."""
        self.state["src"] = src
        self.state["cache"] = None

    def map_state(self, fn):
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.state["src"] is not None:
            self.state["src"] = fn(self.state["src"], 1)
        if self.state["cache"] is not None:
            _recursive_map(self.state["cache"])
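
    # Usage sketch for ``map_state`` (illustrative): during beam search the
    # cached source and per-layer keys/values can be reordered along their
    # batch dimension, e.g. with a hypothetical ``select_indices`` tensor:
    #
    #     decoder.map_state(
    #         lambda state, dim: state.index_select(dim, select_indices)
    #     )
    #
    # ``state["src"]`` is mapped with dim=1, while cached tensors use the
    # ``batch_dim`` passed by ``_recursive_map`` (0 by default).
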
    def detach_state(self):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def update_dropout(self, dropout, attention_dropout):
        self.embeddings.update_dropout(dropout)
        for layer in self.transformer_layers:
            layer.update_dropout(dropout, attention_dropout)


class TransformerDecoder(TransformerDecoderBase):
"""The Transformer decoder from "Attention is All You Need". |
|
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17` |
|
|
|
.. mermaid:: |
|
|
|
graph BT |
|
A[input] |
|
B[multi-head self-attn] |
|
BB[multi-head src-attn] |
|
C[feed forward] |
|
O[output] |
|
A --> B |
|
B --> BB |
|
BB --> C |
|
C --> O |
|
|
|
|
|
Args: |
|
num_layers (int): number of decoder layers. |
|
d_model (int): size of the model |
|
heads (int): number of heads |
|
d_ff (int): size of the inner FF layer |
|
copy_attn (bool): if using a separate copy attention |
|
self_attn_type (str): type of self-attention scaled-dot, average |
|
dropout (float): dropout in residual, self-attn(dot) and feed-forward |
|
attention_dropout (float): dropout in context_attn (and self-attn(avg)) |
|
embeddings (onmt.modules.Embeddings): |
|
embeddings to use, should have positional encodings |
|
max_relative_positions (int): |
|
Max distance between inputs in relative positions representations |
|
aan_useffn (bool): Turn on the FFN layer in the AAN decoder |
|
full_context_alignment (bool): |
|
whether enable an extra full context decoder forward for alignment |
|
alignment_layer (int): N° Layer to supervise with for alignment guiding |
|
alignment_heads (int): |
|
N. of cross attention heads to use for alignment guiding |
|
""" |
|
|
|
def __init__( |
|
        self,
        num_layers,
        d_model,
        heads,
        d_ff,
        copy_attn,
        self_attn_type,
        dropout,
        attention_dropout,
        max_relative_positions,
        aan_useffn,
        full_context_alignment,
        alignment_layer,
        alignment_heads,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        super(TransformerDecoder, self).__init__(
            d_model, copy_attn, alignment_layer
        )

        self.transformer_layers = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    d_model,
                    heads,
                    d_ff,
                    dropout,
                    attention_dropout,
                    self_attn_type=self_attn_type,
                    max_relative_positions=max_relative_positions,
                    aan_useffn=aan_useffn,
                    full_context_alignment=full_context_alignment,
                    alignment_heads=alignment_heads,
                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                )
                for i in range(num_layers)
            ]
        )

    def detach_state(self):
self.state["src"] = self.state["src"].detach() |
|
|
|
def forward(self, tgt_emb, memory_bank, src_pad_mask=None, tgt_pad_mask=None, step=None, **kwargs): |
|
"""Decode, possibly stepwise.""" |
|
if step == 0: |
|
self._init_cache(memory_bank) |
|
|
|
batch_size, src_len, src_dim = memory_bank.size() |
|
device = memory_bank.device |
|
if src_pad_mask is None: |
|
src_pad_mask = torch.zeros((batch_size, 1, src_len), dtype=torch.bool, device=device) |
|
output = tgt_emb |
|
batch_size, tgt_len, tgt_dim = tgt_emb.size() |
|
if tgt_pad_mask is None: |
|
tgt_pad_mask = torch.zeros((batch_size, 1, tgt_len), dtype=torch.bool, device=device) |
|
|
|
future = kwargs.pop("future", False) |
|
with_align = kwargs.pop("with_align", False) |
|
attn_aligns = [] |
|
hiddens = [] |
|
|
|
for i, layer in enumerate(self.transformer_layers): |
|
layer_cache = ( |
|
self.state["cache"]["layer_{}".format(i)] |
|
if step is not None |
|
else None |
|
) |
|
output, attn, attn_align = layer( |
|
output, |
|
memory_bank, |
|
src_pad_mask, |
|
tgt_pad_mask, |
|
layer_cache=layer_cache, |
|
step=step, |
|
with_align=with_align, |
|
future=future |
|
) |
|
hiddens.append(output) |
|
if attn_align is not None: |
|
attn_aligns.append(attn_align) |
|
|
|
output = self.layer_norm(output) |
|
|
|
attns = {"std": attn} |
|
if self._copy: |
|
attns["copy"] = attn |
|
if with_align: |
|
attns["align"] = attn_aligns[self.alignment_layer] |
|
|
|
|
|
|
|
return output, attns, hiddens |
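
    # Usage sketch (illustrative only; sizes and the embedded inputs are
    # assumptions):
    #
    #     decoder = TransformerDecoder(
    #         num_layers=2, d_model=16, heads=4, d_ff=32, copy_attn=False,
    #         self_attn_type="scaled-dot", dropout=0.0, attention_dropout=0.0,
    #         max_relative_positions=0, aan_useffn=False,
    #         full_context_alignment=False, alignment_layer=-1,
    #         alignment_heads=0,
    #     )
    #     memory_bank = torch.rand(2, 7, 16)   # encoder states
    #     tgt_emb = torch.rand(2, 5, 16)       # embedded target prefix
    #
    #     # Full-sequence (training-style) pass: no cache, future mask applied.
    #     dec_out, attns, hiddens = decoder(tgt_emb, memory_bank)
    #
    #     # Stepwise (inference-style) decoding: step=0 initializes the cache.
    #     for t in range(5):
    #         dec_out, attns, hiddens = decoder(
    #             tgt_emb[:, t:t + 1, :], memory_bank, step=t
    #         )
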
    def _init_cache(self, memory_bank):
        self.state["cache"] = {}
        for i, layer in enumerate(self.transformer_layers):
            # Per-layer cache of projected keys/values, filled by the
            # attention modules during stepwise decoding.
            layer_cache = {
                "memory_keys": None,
                "memory_values": None,
                "self_keys": None,
                "self_values": None,
            }
            self.state["cache"]["layer_{}".format(i)] = layer_cache