# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Transformer speech recognition model (pytorch)."""

from argparse import Namespace
from distutils.util import strtobool
import logging
import math

import numpy
import torch

from espnet.nets.ctc_prefix_score import CTCPrefixScore
from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.ctc import CTC
from espnet.nets.pytorch_backend.nets_utils import get_subsample
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport
from espnet.nets.scorers.ctc import CTCPrefixScorer


class E2E(torch.nn.Module):
    """E2E module.

    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        group = parser.add_argument_group("transformer model setting")
        group.add_argument(
            "--transformer-init",
            type=str,
            default="pytorch",
            choices=[
                "pytorch",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
            ],
            help="how to initialize transformer parameters",
        )
        group.add_argument(
            "--transformer-input-layer",
            type=str,
            default="conv2d",
            choices=["conv3d", "conv2d", "conv1d", "linear", "embed"],
            help="transformer input layer type",
        )
        group.add_argument(
            "--transformer-encoder-attn-layer-type",
            type=str,
            default="mha",
            choices=["mha", "rel_mha", "legacy_rel_mha"],
            help="transformer encoder attention layer type",
        )
        group.add_argument(
            "--transformer-attn-dropout-rate",
            default=None,
            type=float,
            help="dropout in transformer attention. use --dropout-rate if None is set",
        )
        group.add_argument(
            "--transformer-lr",
            default=10.0,
            type=float,
            help="Initial value of learning rate",
        )
        group.add_argument(
            "--transformer-warmup-steps",
            default=25000,
            type=int,
            help="optimizer warmup steps",
        )
        group.add_argument(
            "--transformer-length-normalized-loss",
            default=True,
            type=strtobool,
            help="normalize loss by length",
        )
        group.add_argument(
            "--dropout-rate",
            default=0.0,
            type=float,
            help="Dropout rate for the encoder",
        )
        group.add_argument(
            "--macaron-style",
            default=False,
            type=strtobool,
            help="Whether to use macaron style for positionwise layer",
        )
        # -- input
        group.add_argument(
            "--a-upsample-ratio",
            default=1,
            type=int,
            help="Upsample rate for audio",
        )
        group.add_argument(
            "--relu-type",
            default="swish",
            type=str,
            help="the type of activation layer",
        )
        # Encoder
        group.add_argument(
            "--elayers",
            default=4,
            type=int,
            help="Number of encoder layers (for shared recognition part "
            "in multi-speaker asr mode)",
        )
        group.add_argument(
            "--eunits",
            "-u",
            default=300,
            type=int,
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--use-cnn-module",
            default=False,
            type=strtobool,
            help="Use convolution module or not",
        )
        group.add_argument(
            "--cnn-module-kernel",
            default=31,
            type=int,
            help="Kernel size of convolution module.",
        )
        # Attention
        group.add_argument(
            "--adim",
            default=320,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aheads",
            default=4,
            type=int,
            help="Number of heads for multi head attention",
        )
        group.add_argument(
            "--zero-triu",
            default=False,
            type=strtobool,
            help="If true, zero the upper triangular part of attention matrix.",
        )
        # Relative positional encoding
        group.add_argument(
            "--rel-pos-type",
            type=str,
            default="legacy",
            choices=["legacy", "latest"],
            help="Whether to use the latest relative positional encoding or the legacy one. "
            "The legacy relative positional encoding will be deprecated in the future. "
            "More details can be found in https://github.com/espnet/espnet/pull/2816.",
        )
        # Decoder
        group.add_argument(
            "--dlayers", default=1, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=320, type=int, help="Number of decoder hidden units"
        )
        # -- pretrain
        group.add_argument(
            "--pretrain-dataset",
            default="",
            type=str,
            help="dataset the encoder was pre-trained on",
        )
        # -- custom name
        group.add_argument(
            "--custom-pretrain-name",
            default="",
            type=str,
            help="name of the pre-trained model for the encoder",
        )
        return parser
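
    # Hypothetical usage sketch (not part of the original file): the options
    # above are consumed through a standard argparse parser, e.g.
    #
    #     parser = argparse.ArgumentParser()
    #     E2E.add_arguments(parser)
    #     args = parser.parse_args(["--elayers", "12", "--adim", "256"])
    #
    # Fields such as mtlalpha, lsm_weight, and ctc_type are registered by other
    # argument groups in the surrounding training script.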

    def attention_plot_class(self):
        """Return PlotAttentionReport."""
        return PlotAttentionReport

    def __init__(self, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        # Check the relative positional encoding type
        self.rel_pos_type = getattr(args, "rel_pos_type", None)
        if (
            self.rel_pos_type is None
            and args.transformer_encoder_attn_layer_type == "rel_mha"
        ):
            args.transformer_encoder_attn_layer_type = "legacy_rel_mha"
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
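        # NOTE: the input dimension is fixed to 80 here (typically 80-channel
        # log-mel filterbank features); it is not exposed as a CLI option.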
        idim = 80
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            encoder_attn_layer_type=args.transformer_encoder_attn_layer_type,
            macaron_style=args.macaron_style,
            use_cnn_module=args.use_cnn_module,
            cnn_module_kernel=args.cnn_module_kernel,
            zero_triu=getattr(args, "zero_triu", False),
            a_upsample_ratio=args.a_upsample_ratio,
            relu_type=getattr(args, "relu_type", "swish"),
        )
        self.transformer_input_layer = args.transformer_input_layer
        self.a_upsample_ratio = args.a_upsample_ratio
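        # The attention decoder is only needed when the attention branch is
        # active; pure-CTC training (mtlalpha == 1) skips it entirely.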
        if args.mtlalpha < 1:
            self.decoder = Decoder(
                odim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            )
        else:
            self.decoder = None
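        # Special token ids: index 0 is the CTC blank, and <sos>/<eos> share
        # the last index of the output vocabulary.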
        self.blank = 0
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
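        # mtlalpha is the CTC weight in the joint CTC/attention objective:
        # 0.0 disables the CTC branch, 1.0 makes training CTC-only.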
        if args.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
            )
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
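
    # Note: the dict returned by ``scorers`` is what espnet's beam search
    # consumes; the decoder scores hypotheses autoregressively while
    # CTCPrefixScorer rescores prefixes with the CTC branch.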

    def encode(self, x, extract_resnet_feats=False):
        """Encode acoustic features.

        :param ndarray x: source acoustic feature (T, D)
        :param bool extract_resnet_feats: if True, return the frontend
            (ResNet) features instead of the full encoder output
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0)
        if extract_resnet_feats:
            resnet_feats = self.encoder(
                x,
                None,
                extract_resnet_feats=extract_resnet_feats,
            )
            return resnet_feats.squeeze(0)
        else:
            enc_output, _ = self.encoder(x, None)
            return enc_output.squeeze(0)
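

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file). The option
    # values and odim=500 are hypothetical; fields like mtlalpha, lsm_weight,
    # ctc_type, report_cer, and report_wer are normally registered by other
    # argument groups in the training script, so they are set by hand here.
    import argparse

    parser = argparse.ArgumentParser()
    E2E.add_arguments(parser)
    args = parser.parse_args([])
    args.mtlalpha = 0.3
    args.lsm_weight = 0.1
    args.ctc_type = "builtin"
    args.report_cer = False
    args.report_wer = False
    model = E2E(odim=500, args=args)
    print(model)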