File size: 17,194 Bytes
5e9bd47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 |
"""
Implementation of "Attention is All You Need" and of
subsequent transformer based architectures
"""
import torch
import torch.nn as nn
from onmt.decoders.decoder import DecoderBase
from onmt.modules import MultiHeadedAttention, AverageAttention
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction
from onmt.utils.misc import sequence_mask
class TransformerDecoderLayerBase(nn.Module):
    """Shared machinery for transformer decoder layers.

    Builds the self-attention sub-layer, the position-wise feed-forward
    sub-layer, a layer norm and a dropout module, and implements the
    (optionally double) decoder pass used for alignment supervision.
    Subclasses must implement ``_forward``.
    """

    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        """
        Args:
            d_model (int): the dimension of keys/values/queries in
                :class:`MultiHeadedAttention`, also the input size of
                the first-layer of the :class:`PositionwiseFeedForward`.
            heads (int): the number of heads for MultiHeadedAttention.
            d_ff (int): the hidden (second-layer) size of the
                :class:`PositionwiseFeedForward`.
            dropout (float): dropout in residual, self-attn(dot) and
                feed-forward
            attention_dropout (float): dropout in context_attn (and
                self-attn(avg))
            self_attn_type (string): type of self-attention, either
                ``"scaled-dot"`` or ``"average"``
            max_relative_positions (int):
                Max distance between inputs in relative positions
                representations
            aan_useffn (bool): Turn on the FFN layer in the AAN decoder
            full_context_alignment (bool):
                whether to enable an extra full context decoder forward for
                alignment
            alignment_heads (int):
                N. of cross attention heads to use for alignment guiding
            pos_ffn_activation_fn (ActivationFunction):
                activation function choice for PositionwiseFeedForward layer

        Raises:
            ValueError: if ``self_attn_type`` is not ``"scaled-dot"`` or
                ``"average"``.
        """
        super(TransformerDecoderLayerBase, self).__init__()
        if self_attn_type == "scaled-dot":
            self.self_attn = MultiHeadedAttention(
                heads,
                d_model,
                dropout=attention_dropout,
                max_relative_positions=max_relative_positions,
            )
        elif self_attn_type == "average":
            self.self_attn = AverageAttention(
                d_model, dropout=attention_dropout, aan_useffn=aan_useffn
            )
        else:
            # Fail fast: otherwise ``self.self_attn`` would be left unset and
            # the layer would only break later, on first use.
            raise ValueError(
                f"self_attn_type {self_attn_type!r} not supported, "
                "expected 'scaled-dot' or 'average'"
            )
        self.feed_forward = PositionwiseFeedForward(
            d_model, d_ff, dropout, pos_ffn_activation_fn
        )
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        self.full_context_alignment = full_context_alignment
        self.alignment_heads = alignment_heads

    def forward(self, *args, **kwargs):
        """Extend `_forward` for (possibly) multiple decoder passes.

        There is always a default (future masked) decoder forward pass.
        There is possibly a second, future-aware decoder pass for jointly
        learning full context alignment, :cite:`garg2019jointly`.

        Args:
            * All arguments of _forward.
            with_align (bool): whether to return alignment attention.

        Returns:
            (FloatTensor, FloatTensor, FloatTensor or None):

            * output ``(batch_size, T, model_dim)``
            * top_attn ``(batch_size, T, src_len)``
            * attn_align ``(batch_size, T, src_len)`` or None
        """
        with_align = kwargs.pop("with_align", False)
        output, attns = self._forward(*args, **kwargs)
        # Only the first head of the context attention is exposed.
        top_attn = attns[:, 0, :, :].contiguous()
        attn_align = None
        if with_align:
            if self.full_context_alignment:
                # return _, (B, Q_len, K_len)
                # Merge the override into kwargs instead of passing
                # ``future=`` alongside ``**kwargs``: callers may already
                # supply ``future``, which would raise "got multiple values
                # for keyword argument 'future'".
                _, attns = self._forward(*args, **{**kwargs, "future": True})
            if self.alignment_heads > 0:
                attns = attns[:, : self.alignment_heads, :, :].contiguous()
            # layer average attention across heads, get ``(B, Q, K)``
            # Case 1: no full_context, no align heads -> layer avg baseline
            # Case 2: no full_context, 1 align heads -> guided align
            # Case 3: full_context, 1 align heads -> full cte guided align
            attn_align = attns.mean(dim=1)
        return output, top_attn, attn_align

    def update_dropout(self, dropout, attention_dropout):
        """Set new dropout rates on all sub-modules owned by this base."""
        self.self_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.drop.p = dropout

    def _forward(self, *args, **kwargs):
        """One decoder pass; implemented by subclasses."""
        raise NotImplementedError

    def _compute_dec_mask(self, tgt_pad_mask, future):
        """Build the self-attention mask (True marks positions to mask out).

        Args:
            tgt_pad_mask (bool): ``(batch_size, 1, T)`` target padding mask.
            future (bool): if True, do not add the causal (future) mask.

        Returns:
            ``(B, T, T)`` mask when ``future`` is False, otherwise the
            padding mask ``(B, 1, T)`` unchanged.
        """
        if future:  # only mask padding, result mask in (B, 1, T)
            return tgt_pad_mask
        # apply future_mask, result mask in (B, T, T)
        tgt_len = tgt_pad_mask.size(-1)
        future_mask = torch.ones(
            [tgt_len, tgt_len],
            device=tgt_pad_mask.device,
            dtype=torch.uint8,
        )
        # Strictly upper triangle: position t may not see positions > t.
        future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len).bool()
        return torch.gt(tgt_pad_mask + future_mask, 0)

    def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
        """Dispatch self-attention according to the configured module type.

        Raises:
            ValueError: if ``self.self_attn`` is of an unsupported type.
        """
        if isinstance(self.self_attn, MultiHeadedAttention):
            return self.self_attn(
                inputs_norm,
                inputs_norm,
                inputs_norm,
                mask=dec_mask,
                layer_cache=layer_cache,
                attn_type="self",
            )
        elif isinstance(self.self_attn, AverageAttention):
            return self.self_attn(
                inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
            )
        else:
            raise ValueError(
                f"self attention {type(self.self_attn)} not supported"
            )
class TransformerDecoderLayer(TransformerDecoderLayerBase):
    """Transformer decoder layer block in Pre-Norm style.

    Pre-Norm style is an improvement w.r.t. the original paper's Post-Norm
    style, providing better convergence speed and performance. This is also
    the actual implementation in tensor2tensor and is also available in
    fairseq. See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.

    .. mermaid::

        graph LR
        %% "*SubLayer" can be self-attn, src-attn or feed forward block
        A(input) --> B[Norm]
        B --> C["*SubLayer"]
        C --> D[Drop]
        D --> E((+))
        A --> E
        E --> F(out)
    """

    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        """Build the layer.

        Every argument has the meaning documented on
        :class:`TransformerDecoderLayerBase`. On top of the base sub-layers
        this adds a context (source) attention and the layer norm applied
        before it.
        """
        super().__init__(
            d_model,
            heads,
            d_ff,
            dropout,
            attention_dropout,
            self_attn_type=self_attn_type,
            max_relative_positions=max_relative_positions,
            aan_useffn=aan_useffn,
            full_context_alignment=full_context_alignment,
            alignment_heads=alignment_heads,
            pos_ffn_activation_fn=pos_ffn_activation_fn,
        )
        # Registered after the base sub-layers so submodule order is kept.
        self.context_attn = MultiHeadedAttention(
            heads, d_model, dropout=attention_dropout
        )
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)

    def update_dropout(self, dropout, attention_dropout):
        """Propagate new dropout rates, including to the context attention."""
        super().update_dropout(dropout, attention_dropout)
        self.context_attn.update_dropout(attention_dropout)

    def _forward(
        self,
        inputs,
        memory_bank,
        src_pad_mask,
        tgt_pad_mask,
        layer_cache=None,
        step=None,
        future=False,
    ):
        """Run one pre-norm decoder pass (T may be 1 when decoding stepwise).

        Args:
            inputs (FloatTensor): ``(batch_size, T, model_dim)``
            memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
            src_pad_mask (bool): ``(batch_size, 1, src_len)``
            tgt_pad_mask (bool): ``(batch_size, 1, T)``
            layer_cache (dict or None): cached layer info when stepwise decode
            step (int or None): stepwise decoding counter
            future (bool): If set True, do not apply future_mask.

        Returns:
            (FloatTensor, FloatTensor):

            * output ``(batch_size, T, model_dim)``
            * attns ``(batch_size, head, T, src_len)``
        """
        # A single target position has nothing to mask in self-attention.
        self_mask = (
            self._compute_dec_mask(tgt_pad_mask, future)
            if inputs.size(1) > 1
            else None
        )

        # Sub-layer 1: normalised self-attention plus residual.
        self_out, _ = self._forward_self_attn(
            self.layer_norm_1(inputs), self_mask, layer_cache, step
        )
        hidden = self.drop(self_out) + inputs

        # Sub-layer 2: normalised attention over memory_bank plus residual.
        ctx_out, attns = self.context_attn(
            memory_bank,
            memory_bank,
            self.layer_norm_2(hidden),
            mask=src_pad_mask,
            layer_cache=layer_cache,
            attn_type="context",
        )

        # Sub-layer 3: position-wise feed-forward on the residual sum.
        return self.feed_forward(self.drop(ctx_out) + hidden), attns
class TransformerDecoderBase(DecoderBase):
    """Base class for transformer decoders.

    Holds the decoder state dict, the final layer norm and the alignment
    layer index, and provides the ``from_opt`` constructor. Subclasses must
    set ``self.transformer_layers`` and implement ``forward`` and
    ``detach_state``.
    """

    def __init__(self, d_model, copy_attn, alignment_layer):
        super(TransformerDecoderBase, self).__init__()
        # Decoder State: "src" and per-layer "cache" (see ``init_state``).
        self.state = {}
        # previously, there was a GlobalAttention module here for copy
        # attention. But it was never actually used -- the "copy" attention
        # just reuses the context attention.
        self._copy = copy_attn
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.alignment_layer = alignment_layer

    @staticmethod
    def _first_if_list(value):
        """Return ``value[0]`` for list-valued options, else ``value``."""
        return value[0] if isinstance(value, list) else value

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor.

        Args:
            opt: parsed options; ``dropout`` and ``attention_dropout`` may be
                lists, in which case their first element is used.
            embeddings: accepted for interface compatibility only and not
                forwarded: ``TransformerDecoder.__init__`` has no
                ``embeddings`` parameter (its ``forward`` takes already
                embedded targets). Passing it positionally, as before,
                shifted every later argument and raised ``TypeError``.
        """
        # Keyword arguments keep this call robust to parameter reordering.
        return cls(
            num_layers=opt.dec_layers,
            d_model=opt.dec_rnn_size,
            heads=opt.heads,
            d_ff=opt.transformer_ff,
            copy_attn=opt.copy_attn,
            self_attn_type=opt.self_attn_type,
            dropout=cls._first_if_list(opt.dropout),
            attention_dropout=cls._first_if_list(opt.attention_dropout),
            max_relative_positions=opt.max_relative_positions,
            aan_useffn=opt.aan_useffn,
            full_context_alignment=opt.full_context_alignment,
            alignment_layer=opt.alignment_layer,
            alignment_heads=opt.alignment_heads,
            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
        )

    def init_state(self, src, memory_bank, enc_hidden):
        """Initialize decoder state.

        Stores ``src`` and clears the cache; ``memory_bank`` and
        ``enc_hidden`` are not used here.
        """
        self.state["src"] = src
        self.state["cache"] = None

    def map_state(self, fn):
        """Apply ``fn(tensor, batch_dim)`` to every tensor in the state.

        ``src`` is mapped along dim 1; cached tensors (nested dicts) along
        dim 0. ``None`` entries are skipped.
        """

        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.state["src"] is not None:
            self.state["src"] = fn(self.state["src"], 1)
        if self.state["cache"] is not None:
            _recursive_map(self.state["cache"])

    def detach_state(self):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def update_dropout(self, dropout, attention_dropout):
        """Set new dropout rates on the embeddings (if any) and all layers."""
        # ``from_opt`` does not attach embeddings to the decoder, so the
        # attribute may be absent; only update it when someone set it.
        embeddings = getattr(self, "embeddings", None)
        if embeddings is not None:
            embeddings.update_dropout(dropout)
        for layer in self.transformer_layers:
            layer.update_dropout(dropout, attention_dropout)
class TransformerDecoder(TransformerDecoderBase):
    """The Transformer decoder from "Attention is All You Need".
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    .. mermaid::

        graph BT
        A[input]
        B[multi-head self-attn]
        BB[multi-head src-attn]
        C[feed forward]
        O[output]
        A --> B
        B --> BB
        BB --> C
        C --> O

    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        copy_attn (bool): if True, ``forward`` also returns the context
            attention under the ``"copy"`` key
        self_attn_type (str): type of self-attention scaled-dot, average
        dropout (float): dropout in residual, self-attn(dot) and feed-forward
        attention_dropout (float): dropout in context_attn (and self-attn(avg))
        max_relative_positions (int):
            Max distance between inputs in relative positions representations
        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
        full_context_alignment (bool):
            whether to enable an extra full context decoder forward for
            alignment
        alignment_layer (int): index of the layer whose alignment attention
            is returned under ``"align"`` for alignment guiding
        alignment_heads (int):
            N. of cross attention heads to use for alignment guiding
        pos_ffn_activation_fn (ActivationFunction):
            activation function choice for PositionwiseFeedForward layer

    This decoder owns no embeddings: ``forward`` takes already embedded
    targets.
    """

    def __init__(
        self,
        num_layers,
        d_model,
        heads,
        d_ff,
        copy_attn,
        self_attn_type,
        dropout,
        attention_dropout,
        max_relative_positions,
        aan_useffn,
        full_context_alignment,
        alignment_layer,
        alignment_heads,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        super(TransformerDecoder, self).__init__(
            d_model, copy_attn, alignment_layer
        )
        self.transformer_layers = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    d_model,
                    heads,
                    d_ff,
                    dropout,
                    attention_dropout,
                    self_attn_type=self_attn_type,
                    max_relative_positions=max_relative_positions,
                    aan_useffn=aan_useffn,
                    full_context_alignment=full_context_alignment,
                    alignment_heads=alignment_heads,
                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                )
                for _ in range(num_layers)
            ]
        )

    def detach_state(self):
        """Detach ``src`` from the autograd graph (no-op when it is None)."""
        src = self.state["src"]
        if src is not None:
            self.state["src"] = src.detach()

    def forward(
        self,
        tgt_emb,
        memory_bank,
        src_pad_mask=None,
        tgt_pad_mask=None,
        step=None,
        **kwargs,
    ):
        """Decode, possibly stepwise.

        Args:
            tgt_emb (FloatTensor): embedded target ``(batch_size, T, model_dim)``
            memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
            src_pad_mask (bool or None): ``(batch_size, 1, src_len)``, True
                at positions to mask; all-False when omitted.
            tgt_pad_mask (bool or None): ``(batch_size, 1, T)``, True at
                positions to mask; all-False when omitted.
            step (int or None): decoding step. ``0`` (re)initialises the
                per-layer cache; ``None`` disables caching.
            future (bool, keyword): if True, layers skip the future mask.
            with_align (bool, keyword): also return alignment attention.

        Returns:
            (FloatTensor, dict, list):

            * output ``(batch_size, T, model_dim)`` after the final norm
            * attns: ``"std"`` (and ``"copy"`` if copy attention is on) is
              the first-head context attention of the last layer;
              ``"align"`` is added when ``with_align`` is set
            * hiddens: per-layer outputs, before the final layer norm
        """
        if step == 0:
            self._init_cache(memory_bank)
        batch_size, src_len, _ = memory_bank.size()
        device = memory_bank.device
        if src_pad_mask is None:
            src_pad_mask = torch.zeros(
                (batch_size, 1, src_len), dtype=torch.bool, device=device
            )
        output = tgt_emb
        batch_size, tgt_len, _ = tgt_emb.size()
        if tgt_pad_mask is None:
            tgt_pad_mask = torch.zeros(
                (batch_size, 1, tgt_len), dtype=torch.bool, device=device
            )
        future = kwargs.pop("future", False)
        with_align = kwargs.pop("with_align", False)
        # Forward ``future`` only when set: the layers' full-context
        # alignment pass re-invokes ``_forward`` with its own ``future=True``
        # and an extra ``future`` keyword would collide with it.
        layer_kwargs = {"future": future} if future else {}
        attn_aligns = []
        hiddens = []
        for i, layer in enumerate(self.transformer_layers):
            layer_cache = (
                self.state["cache"]["layer_{}".format(i)]
                if step is not None
                else None
            )
            output, attn, attn_align = layer(
                output,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                layer_cache=layer_cache,
                step=step,
                with_align=with_align,
                **layer_kwargs,
            )
            hiddens.append(output)
            if attn_align is not None:
                attn_aligns.append(attn_align)
        output = self.layer_norm(output)  # (B, L, D)
        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn
        if with_align:
            attns["align"] = attn_aligns[self.alignment_layer]  # `(B, Q, K)`
            # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
        # TODO change the way attns is returned dict => list or tuple (onnx)
        return output, attns, hiddens

    def _init_cache(self, memory_bank):
        """Create an empty key/value cache for every layer.

        ``memory_bank`` is not used here; it is kept for interface
        compatibility.
        """
        self.state["cache"] = {
            "layer_{}".format(i): {
                "memory_keys": None,
                "memory_values": None,
                "self_keys": None,
                "self_values": None,
            }
            for i in range(len(self.transformer_layers))
        }
|