Spaces:
Runtime error
Runtime error
File size: 23,149 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 |
#!/usr/bin/env python3
from ast import literal_eval
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
@register_model("s2t_berard")
class BerardModel(FairseqEncoderDecoderModel):
    """Implementation of a model similar to https://arxiv.org/abs/1802.04200

    Paper title: End-to-End Automatic Speech Translation of Audiobooks
    An implementation is available in tensorflow at
    https://github.com/eske/seq2seq
    Relevant files in this implementation are the config
    (https://github.com/eske/seq2seq/blob/master/config/LibriSpeech/AST.yaml)
    and the model code
    (https://github.com/eske/seq2seq/blob/master/translate/models.py).
    The encoder and decoder try to be close to the original implementation.
    The attention is an MLP as in Bahdanau et al.
    (https://arxiv.org/abs/1409.0473).

    Decoder state initialization: the decoder LSTM hidden states start from
    the plain time-average of the encoder outputs (no learned projection, no
    padding mask), and the decoder cell states start at zero.
    """

    @staticmethod
    def add_args(parser):
        """Register the model-specific command-line options on *parser*."""
        parser.add_argument(
            "--input-layers",
            type=str,
            metavar="EXPR",
            help="List of linear layer dimensions. These "
            "layers are applied to the input features and "
            "are followed by tanh and possibly dropout.",
        )
        parser.add_argument(
            "--dropout",
            type=float,
            metavar="D",
            help="Dropout probability to use in the encoder/decoder. "
            "Note that this parameters control dropout in various places, "
            "there is no fine-grained control for dropout for embeddings "
            "vs LSTM layers for example.",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="Number of encoder input channels. " "Typically value is 1.",
        )
        parser.add_argument(
            "--conv-layers",
            type=str,
            metavar="EXPR",
            help="List of conv layers " "(format: (channels, kernel, stride)).",
        )
        parser.add_argument(
            "--num-blstm-layers",
            type=int,
            metavar="N",
            help="Number of encoder bi-LSTM layers.",
        )
        parser.add_argument(
            "--lstm-size", type=int, metavar="N", help="LSTM hidden size."
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="Embedding dimension of the decoder target tokens.",
        )
        parser.add_argument(
            "--decoder-hidden-dim",
            type=int,
            metavar="N",
            help="Decoder LSTM hidden dimension.",
        )
        parser.add_argument(
            "--decoder-num-layers",
            type=int,
            metavar="N",
            help="Number of decoder LSTM layers.",
        )
        parser.add_argument(
            "--attention-dim",
            type=int,
            metavar="N",
            help="Hidden layer dimension in MLP attention.",
        )
        parser.add_argument(
            "--output-layer-dim",
            type=int,
            metavar="N",
            help="Hidden layer dim for linear layer prior to output projection.",
        )
        parser.add_argument(
            "--load-pretrained-encoder-from",
            type=str,
            metavar="STR",
            help="model to take encoder weights from (for initialization)",
        )
        parser.add_argument(
            "--load-pretrained-decoder-from",
            type=str,
            metavar="STR",
            help="model to take decoder weights from (for initialization)",
        )

    @classmethod
    def build_encoder(cls, args, task):
        """Build the speech encoder, optionally loading pretrained weights.

        The channel count comes from ``args.input_channels`` when that
        attribute is set (it is not defined by ``add_args``; presumably the
        task fills it in). Otherwise it falls back to the ``--in-channels``
        option declared in ``add_args``, which was previously ignored.
        """
        in_channels = getattr(args, "input_channels", None)
        if in_channels is None:
            in_channels = args.in_channels
        encoder = BerardEncoder(
            input_layers=literal_eval(args.input_layers),
            conv_layers=literal_eval(args.conv_layers),
            in_channels=in_channels,
            input_feat_per_channel=args.input_feat_per_channel,
            num_blstm_layers=args.num_blstm_layers,
            lstm_size=args.lstm_size,
            dropout=args.dropout,
        )
        if getattr(args, "load_pretrained_encoder_from", None) is not None:
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
        return encoder

    @classmethod
    def build_decoder(cls, args, task):
        """Build the attentional LSTM decoder, optionally loading pretrained weights."""
        decoder = LSTMDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            num_layers=args.decoder_num_layers,
            hidden_size=args.decoder_hidden_dim,
            dropout=args.dropout,
            encoder_output_dim=2 * args.lstm_size,  # bidirectional encoder LSTM
            attention_dim=args.attention_dim,
            output_layer_dim=args.output_layer_dim,
        )
        if getattr(args, "load_pretrained_decoder_from", None) is not None:
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_decoder_from
            )
        return decoder

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        encoder = cls.build_encoder(args, task)
        decoder = cls.build_decoder(args, task)
        return cls(encoder, decoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Return (log-)probabilities from the decoder output and tag them batch-first."""
        # net_output is the decoder output tuple; LSTMDecoder.forward returns
        # batch-first logits (B x T x V), so the normalized result is
        # batch-first as well.
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs
class BerardEncoder(FairseqEncoder):
    """Speech encoder: per-frame linear layers, 2-D convolutions over
    (time, feature), then a stack of bidirectional LSTMs."""

    def __init__(
        self,
        input_layers: List[int],
        conv_layers: List[Tuple[int, int, int]],
        in_channels: int,
        input_feat_per_channel: int,
        num_blstm_layers: int,
        lstm_size: int,
        dropout: float,
    ):
        """
        Args:
            input_layers: list of linear layer dimensions, applied to the
                feature axis of every frame. In ``forward`` each layer is
                followed by (optional) dropout and then tanh. May be empty.
            conv_layers: list of conv2d layer configurations. A configuration is
                a tuple (out_channels, conv_kernel_size, stride). May be empty.
            in_channels: number of input channels.
            input_feat_per_channel: number of input features per channel. These
                are speech features, typically 40 or 80.
            num_blstm_layers: number of bidirectional LSTM layers.
            lstm_size: size of the LSTM hidden (and cell) size.
            dropout: dropout probability. It is used after the input linear
                layers, inside the BLSTM stack (nn.LSTM ``dropout``), and on
                the BLSTM output, but not on the convolutional layers.
        """
        super().__init__(None)

        self.input_layers = nn.ModuleList()
        in_features = input_feat_per_channel
        for out_features in input_layers:
            linear = nn.Linear(in_features, out_features)
            if dropout > 0:
                self.input_layers.append(nn.Sequential(linear, nn.Dropout(p=dropout)))
            else:
                self.input_layers.append(linear)
            in_features = out_features

        self.in_channels = in_channels
        self.input_dim = input_feat_per_channel
        self.conv_kernel_sizes_and_strides = []
        self.conv_layers = nn.ModuleList()
        # Size of the feature axis entering the convolutions: output size of
        # the last input layer, or the raw feature size when there is none.
        lstm_input_dim = in_features
        for out_channels, conv_kernel_size, conv_stride in conv_layers:
            padding = conv_kernel_size // 2
            self.conv_layers.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    conv_kernel_size,
                    stride=conv_stride,
                    padding=padding,
                )
            )
            self.conv_kernel_sizes_and_strides.append((conv_kernel_size, conv_stride))
            in_channels = out_channels
            # Exact Conv2d output size along the feature axis (a plain
            # `//= stride` is wrong for odd sizes or even kernels).
            lstm_input_dim = (
                lstm_input_dim + 2 * padding - conv_kernel_size
            ) // conv_stride + 1
        # After the loop, in_channels is the channel count fed to the LSTM:
        # the last conv's out_channels, or the input channels if no convs.
        lstm_input_dim *= in_channels

        self.lstm_size = lstm_size
        self.num_blstm_layers = num_blstm_layers
        self.lstm = nn.LSTM(
            input_size=lstm_input_dim,
            hidden_size=lstm_size,
            num_layers=num_blstm_layers,
            dropout=dropout,
            bidirectional=True,
        )
        self.output_dim = 2 * lstm_size  # bidirectional
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Args
            src_tokens: padded tensor (B, T, C * feat)
            src_lengths: tensor of original lengths of input utterances (B,).
                When None, every utterance is assumed to span all T frames.

        Returns:
            dict with "encoder_out" (T', B, 2 * lstm_size) and
            "encoder_padding_mask" (T', B), where T' is the length after
            convolutional subsampling.
        """
        bsz, max_seq_len, _ = src_tokens.size()
        if src_lengths is None:
            src_lengths = src_tokens.new_full((bsz,), max_seq_len, dtype=torch.long)

        # (B, T, C * feat) -> (B, C, T, feat)
        x = (
            src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
            .transpose(1, 2)
            .contiguous()
        )

        # Linear layers act on the last (feature) axis of every frame.
        for input_layer in self.input_layers:
            x = torch.tanh(input_layer(x))

        for conv_layer in self.conv_layers:
            x = conv_layer(x)

        bsz, _, output_seq_len, _ = x.size()

        # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) ->
        # (T, B, C * feat)
        x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1)

        # Apply the same size formula as the convolutions to the time axis.
        input_lengths = src_lengths.clone()
        for k, s in self.conv_kernel_sizes_and_strides:
            p = k // 2
            input_lengths = (input_lengths.float() + 2 * p - k) / s + 1
            input_lengths = input_lengths.floor().long()

        # pack_padded_sequence requires lengths as a CPU int64 tensor;
        # enforce_sorted=False also accepts batches not sorted by length
        # (outputs are restored to the original batch order on unpacking).
        packed_x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths.cpu(), enforce_sorted=False
        )

        h0 = x.new_zeros(2 * self.num_blstm_layers, bsz, self.lstm_size)
        c0 = x.new_zeros(2 * self.num_blstm_layers, bsz, self.lstm_size)
        packed_outs, _ = self.lstm(packed_x, (h0, c0))

        # unpack outputs and apply dropout
        x, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_outs)
        if self.dropout is not None:
            x = self.dropout(x)

        encoder_padding_mask = (
            lengths_to_padding_mask(output_lengths).to(src_tokens.device).t()
        )

        return {
            "encoder_out": x,  # (T, B, C)
            "encoder_padding_mask": encoder_padding_mask,  # (T, B)
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Return a copy of *encoder_out* whose tensors are reordered along the
        batch dimension (dim 1) by *new_order*; the input dict is not mutated."""
        reordered = dict(encoder_out)
        reordered["encoder_out"] = encoder_out["encoder_out"].index_select(
            1, new_order
        )
        reordered["encoder_padding_mask"] = encoder_out[
            "encoder_padding_mask"
        ].index_select(1, new_order)
        return reordered
class MLPAttention(nn.Module):
    """The original attention from Bahdanau et al. (2014)
    https://arxiv.org/abs/1409.0473, based on a Multi-Layer Perceptron.

    The attention score between position i in the encoder and position j in the
    decoder is: alpha_ij = V_a * tanh(W_ae * enc_i + W_ad * dec_j + b_a)
    """

    def __init__(self, decoder_hidden_state_dim, context_dim, attention_dim):
        """
        Args:
            decoder_hidden_state_dim: feature size of the decoder state.
            context_dim: feature size of the encoder states (``source_hids``).
            attention_dim: hidden size of the scoring MLP.
        """
        super().__init__()
        self.context_dim = context_dim
        self.attention_dim = attention_dim
        # W_ae and b_a
        self.encoder_proj = nn.Linear(context_dim, self.attention_dim, bias=True)
        # W_ad
        self.decoder_proj = nn.Linear(
            decoder_hidden_state_dim, self.attention_dim, bias=False
        )
        # V_a
        self.to_scores = nn.Linear(self.attention_dim, 1, bias=False)

    def forward(self, decoder_state, source_hids, encoder_padding_mask):
        """Attend over the encoder states with one decoder state per batch item.

        Args:
            decoder_state: bsz x decoder_hidden_state_dim
            source_hids: src_len x bsz x context_dim (need not be contiguous)
            encoder_padding_mask: src_len x bsz boolean mask whose True
                entries are excluded from attention, or None.

        Returns:
            Tuple of the attention-weighted context (bsz x context_dim) and the
            attention weights (src_len x bsz, softmax over src_len). A column
            whose positions are all masked yields NaN weights.
        """
        # nn.Linear acts on the last dimension, so no flatten/view is needed;
        # this also accepts non-contiguous source_hids.
        # src_len x bsz x attention_dim
        encoder_component = self.encoder_proj(source_hids)
        # 1 x bsz x attention_dim
        decoder_component = self.decoder_proj(decoder_state).unsqueeze(0)
        # Sum with broadcasting and apply the non linearity
        hidden_att = torch.tanh(decoder_component + encoder_component)
        # Project onto the reals to get attention scores (src_len x bsz)
        attn_scores = self.to_scores(hidden_att).squeeze(-1)

        if encoder_padding_mask is not None:
            # FP16 support: mask in float32, then cast back to the input dtype.
            attn_scores = (
                attn_scores.float()
                .masked_fill(encoder_padding_mask, float("-inf"))
                .type_as(attn_scores)
            )
        # Normalize over source positions (src_len x bsz)
        normalized_masked_attn_scores = F.softmax(attn_scores, dim=0)

        # Sum weighted sources (bsz x context_dim)
        attn_weighted_context = (
            source_hids * normalized_masked_attn_scores.unsqueeze(2)
        ).sum(dim=0)

        return attn_weighted_context, normalized_masked_attn_scores
class LSTMDecoder(FairseqIncrementalDecoder):
    """Stacked LSTMCell decoder with MLP attention and a deep output layer."""

    def __init__(
        self,
        dictionary,
        embed_dim,
        num_layers,
        hidden_size,
        dropout,
        encoder_output_dim,
        attention_dim,
        output_layer_dim,
    ):
        """
        Args:
            dictionary: target text dictionary.
            embed_dim: embedding dimension for target tokens.
            num_layers: number of LSTM layers.
            hidden_size: hidden size for LSTM layers. ``forward`` initializes
                the hidden states with the time-averaged encoder output, so
                this must equal ``encoder_output_dim``.
            dropout: dropout probability. Dropout can be applied to the
                embeddings, the LSTM layers, and the context vector.
            encoder_output_dim: encoder output dimension (hidden size of
                encoder LSTM).
            attention_dim: attention dimension for MLP attention.
            output_layer_dim: size of the linear layer prior to output
                projection.
        """
        super().__init__(dictionary)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = nn.Embedding(num_embeddings, embed_dim, padding_idx)
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None

        # Layer 0 consumes the token embedding; every higher layer consumes
        # the attention context computed from layer 0's output.
        self.layers = nn.ModuleList()
        for layer_id in range(num_layers):
            input_size = embed_dim if layer_id == 0 else encoder_output_dim
            self.layers.append(
                nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
            )

        self.context_dim = encoder_output_dim
        self.attention = MLPAttention(
            decoder_hidden_state_dim=hidden_size,
            context_dim=encoder_output_dim,
            attention_dim=attention_dim,
        )

        # Deep output: [top LSTM state; attention context; token embedding].
        self.deep_output_layer = nn.Linear(
            hidden_size + encoder_output_dim + embed_dim, output_layer_dim
        )
        self.output_projection = nn.Linear(output_layer_dim, num_embeddings)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """Decode *prev_output_tokens* against the encoder output.

        Args:
            prev_output_tokens: (B, T) previous target token ids; during
                incremental decoding only the last column is used.
            encoder_out: dict with "encoder_out" (T_src, B, C) and
                "encoder_padding_mask" (T_src, B), as built by the encoder.
            incremental_state: fairseq incremental-decoding cache, or None.

        Returns:
            Tuple (logits, None) with logits of shape (B, T, vocab).
            Attention weights are not returned: doing so would require
            accounting for the encoder's frame subsampling.
        """
        encoder_padding_mask = encoder_out["encoder_padding_mask"]
        encoder_outs = encoder_out["encoder_out"]

        if incremental_state is not None:
            # Earlier steps are summarized in the cached LSTM state.
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # embed tokens
        embeddings = self.embed_tokens(prev_output_tokens)
        x = embeddings
        if self.dropout is not None:
            x = self.dropout(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental
        # generation)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is not None:
            prev_hiddens, prev_cells = cached_state
        else:
            # Hidden states: unmasked mean of encoder outputs over time
            # (padded frames included); cell states: zeros.
            prev_hiddens = [encoder_outs.mean(dim=0)] * self.num_layers
            prev_cells = [x.new_zeros(bsz, self.hidden_size)] * self.num_layers

        attention_outs = []
        outs = []
        for j in range(seqlen):
            layer_input = x[j, :, :]
            attention_out = None
            for i, layer in enumerate(self.layers):
                # the previous state is one layer below except for the bottom
                # layer where the previous state is the state emitted by the
                # top layer
                prev = (i - 1) % self.num_layers
                hidden, cell = layer(layer_input, (prev_hiddens[prev], prev_cells[prev]))
                if self.dropout is not None:
                    hidden = self.dropout(hidden)
                prev_hiddens[i] = hidden
                prev_cells[i] = cell
                # Attention is computed once per step, from the bottom layer.
                if attention_out is None:
                    attention_out, _ = self.attention(
                        hidden, encoder_outs, encoder_padding_mask
                    )
                    if self.dropout is not None:
                        attention_out = self.dropout(attention_out)
                    attention_outs.append(attention_out)
                layer_input = attention_out

            # collect the output of the top layer
            outs.append(hidden)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self, incremental_state, "cached_state", (prev_hiddens, prev_cells)
        )

        # Collect outputs across time steps (T x B x C), then go batch-first.
        x = torch.stack(outs, dim=0).transpose(0, 1)
        attention_outs_concat = torch.stack(attention_outs, dim=0).transpose(0, 1)

        # concat LSTM output, attention output and embedding
        # before output projection
        x = torch.cat((x, attention_outs_concat, embeddings), dim=2)
        x = torch.tanh(self.deep_output_layer(x))
        if self.dropout is not None:
            x = self.dropout(x)

        # project back to size of vocabulary
        x = self.output_projection(x)

        return x, None

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder the cached LSTM states along the batch dimension (dim 0)."""
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, "cached_state", new_state)
@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard")
def berard(args):
    """Defaults for the original model of "End-to-End Automatic Speech
    Translation of Audiobooks" (https://arxiv.org/abs/1802.04200).

    Attributes already present on *args* keep their value; only missing
    ones receive the defaults listed below, in this order.
    """
    defaults = (
        ("input_layers", "[256, 128]"),
        ("conv_layers", "[(16, 3, 2), (16, 3, 2)]"),
        ("num_blstm_layers", 3),
        ("lstm_size", 256),
        ("dropout", 0.2),
        ("decoder_embed_dim", 128),
        ("decoder_num_layers", 2),
        ("decoder_hidden_dim", 512),
        ("attention_dim", 512),
        ("output_layer_dim", 128),
        ("load_pretrained_encoder_from", None),
        ("load_pretrained_decoder_from", None),
    )
    for attr_name, fallback in defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_256_3_3")
def berard_256_3_3(args):
    """Base architecture with a 3-layer decoder by default.

    Used in
    * "Harnessing Indirect Training Data for End-to-End Automatic Speech
      Translation: Tricks of the Trade" (https://arxiv.org/abs/1909.06515)
    * "CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus"
      (https://arxiv.org/pdf/2002.01320.pdf)
    * "Self-Supervised Representations Improve End-to-End Speech Translation"
      (https://arxiv.org/abs/2006.12124)
    """
    args.decoder_num_layers = getattr(args, "decoder_num_layers", 3)
    # Everything else comes from the base architecture.
    return berard(args)
@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_512_3_2")
def berard_512_3_2(args):
    """512-unit encoder BLSTM (3 layers) with a 2-layer, 1024-unit decoder.

    Missing attributes on *args* are filled here first; the remaining ones
    come from the base ``berard`` architecture.
    """
    overrides = (
        ("num_blstm_layers", 3),
        ("lstm_size", 512),
        ("dropout", 0.3),
        ("decoder_embed_dim", 256),
        ("decoder_num_layers", 2),
        ("decoder_hidden_dim", 1024),
        ("attention_dim", 512),
        ("output_layer_dim", 256),
    )
    for attr_name, fallback in overrides:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    berard(args)
@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_512_5_3")
def berard_512_5_3(args):
    """512-unit encoder BLSTM (5 layers) with a 3-layer, 1024-unit decoder.

    Missing attributes on *args* are filled here first; the remaining ones
    come from the base ``berard`` architecture.
    """
    overrides = (
        ("num_blstm_layers", 5),
        ("lstm_size", 512),
        ("dropout", 0.3),
        ("decoder_embed_dim", 256),
        ("decoder_num_layers", 3),
        ("decoder_hidden_dim", 1024),
        ("attention_dim", 512),
        ("output_layer_dim", 256),
    )
    for attr_name, fallback in overrides:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    berard(args)
|