| """RNN decoder module.""" | |
| import logging | |
| import math | |
| import random | |
| from argparse import Namespace | |
| import numpy as np | |
| import six | |
| import torch | |
| import torch.nn.functional as F | |
| from funasr_detach.models.transformer.utils.scorers.ctc_prefix_score import ( | |
| CTCPrefixScore, | |
| ) | |
| from funasr_detach.models.transformer.utils.scorers.ctc_prefix_score import ( | |
| CTCPrefixScoreTH, | |
| ) | |
| from funasr_detach.models.transformer.utils.scorers.scorer_interface import ( | |
| ScorerInterface, | |
| ) | |
| from funasr_detach.metrics import end_detect | |
| from funasr_detach.models.transformer.utils.nets_utils import mask_by_length | |
| from funasr_detach.models.transformer.utils.nets_utils import pad_list | |
| from funasr_detach.metrics.compute_acc import th_accuracy | |
| from funasr_detach.models.transformer.utils.nets_utils import to_device | |
| from funasr_detach.models.language_model.rnn.attentions import att_to_numpy | |
| MAX_DECODER_OUTPUT = 5 | |
| CTC_SCORING_RATIO = 1.5 | |
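# MAX_DECODER_OUTPUT caps how many utterances forward() logs for debugging;
# CTC_SCORING_RATIO scales the beam width used to pre-prune the candidates
# that are handed to the (comparatively expensive) CTC prefix scorer.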


class Decoder(torch.nn.Module, ScorerInterface):
    """Decoder module.

    :param int eprojs: encoder projection units
    :param int odim: dimension of outputs
    :param str dtype: gru or lstm
    :param int dlayers: number of decoder layers
    :param int dunits: number of decoder units
    :param int sos: start of sequence symbol id
    :param int eos: end of sequence symbol id
    :param torch.nn.Module att: attention module
    :param int verbose: verbose level
    :param list char_list: list of character strings
    :param ndarray labeldist: distribution of label smoothing
    :param float lsm_weight: label smoothing weight
    :param float sampling_probability: scheduled sampling probability
    :param float dropout: dropout rate
    :param bool context_residual: if True, use the context vector for token generation
    :param bool replace_sos: use for multilingual (speech/text) translation
    """

    def __init__(
        self,
        eprojs,
        odim,
        dtype,
        dlayers,
        dunits,
        sos,
        eos,
        att,
        verbose=0,
        char_list=None,
        labeldist=None,
        lsm_weight=0.0,
        sampling_probability=0.0,
        dropout=0.0,
        context_residual=False,
        replace_sos=False,
        num_encs=1,
    ):
        torch.nn.Module.__init__(self)
        self.dtype = dtype
        self.dunits = dunits
        self.dlayers = dlayers
        self.context_residual = context_residual
        self.embed = torch.nn.Embedding(odim, dunits)
        self.dropout_emb = torch.nn.Dropout(p=dropout)

        self.decoder = torch.nn.ModuleList()
        self.dropout_dec = torch.nn.ModuleList()
        self.decoder += [
            (
                torch.nn.LSTMCell(dunits + eprojs, dunits)
                if self.dtype == "lstm"
                else torch.nn.GRUCell(dunits + eprojs, dunits)
            )
        ]
        self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        for _ in six.moves.range(1, self.dlayers):
            self.decoder += [
                (
                    torch.nn.LSTMCell(dunits, dunits)
                    if self.dtype == "lstm"
                    else torch.nn.GRUCell(dunits, dunits)
                )
            ]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]
            # NOTE: dropout is applied only for the vertical connections
            # see https://arxiv.org/pdf/1409.2329.pdf
        self.ignore_id = -1

        if context_residual:
            self.output = torch.nn.Linear(dunits + eprojs, odim)
        else:
            self.output = torch.nn.Linear(dunits, odim)

        self.loss = None
        self.att = att
        self.sos = sos
        self.eos = eos
        self.odim = odim
        self.verbose = verbose
        self.char_list = char_list
        # for label smoothing
        self.labeldist = labeldist
        self.vlabeldist = None
        self.lsm_weight = lsm_weight
        self.sampling_probability = sampling_probability
        self.dropout = dropout
        self.num_encs = num_encs

        # for multilingual E2E-ST
        self.replace_sos = replace_sos

        # large negative constant used as a -inf surrogate when masking scores
        self.logzero = -10000000000.0

    def zero_state(self, hs_pad):
        return hs_pad.new_zeros(hs_pad.size(0), self.dunits)

    def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
        if self.dtype == "lstm":
            z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
            for i in six.moves.range(1, self.dlayers):
                z_list[i], c_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]), (z_prev[i], c_prev[i])
                )
        else:
            z_list[0] = self.decoder[0](ey, z_prev[0])
            for i in six.moves.range(1, self.dlayers):
                z_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
                )
        return z_list, c_list

    def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
        """Decoder forward.

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            [in multi-encoder case, list of torch.Tensor,
            [(B, Tmax_1, D), (B, Tmax_2, D), ...] ]
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
            [in multi-encoder case, list of torch.Tensor,
            [(B), (B), ...] ]
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor
            (B, Lmax)
        :param int strm_idx: stream index indicating the index of the decoding stream
        :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy
        :rtype: float
        :return: perplexity
        :rtype: float
        """
        # to support multiple-encoder asr mode; in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # TODO(kan-bayashi): need to find a smarter way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys

        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select the attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att) - 1)

        # hlens should be a list of lists of integers
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # padding for ys with -1
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)

        # get dim, length info
        batch = ys_out_pad.size(0)
        olength = ys_out_pad.size(1)
        for idx in range(self.num_encs):
            logging.info(
                self.__class__.__name__
                + " Number of Encoders: {}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, hlens[idx]
                )
            )
        logging.info(
            self.__class__.__name__
            + " output lengths: "
            + str([y.size(0) for y in ys_out])
        )

        # initialization
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in six.moves.range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
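            # scheduled sampling: with probability `sampling_probability`, feed
            # back the model's own previous prediction (argmax of the last
            # output) instead of the ground-truth token, reducing the mismatch
            # between teacher-forced training and free-running decoding.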
            if i > 0 and random.random() < self.sampling_probability:
                logging.info(" scheduled sampling ")
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(hs_pad[0], z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

        z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)

        # compute loss
        y_all = self.output(z_all)
        self.loss = F.cross_entropy(
            y_all,
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="mean",
        )
        # compute perplexity
        ppl = math.exp(self.loss.item())
        # scale the per-token loss by the mean target length
        # (-1: eos, which is removed in the loss computation)
        self.loss *= np.mean([len(x) for x in ys_in]) - 1
        acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
        logging.info("att loss:" + "".join(str(self.loss.item()).split("\n")))

        # show predicted character sequence for debug
        if self.verbose > 0 and self.char_list is not None:
            ys_hat = y_all.view(batch, olength, -1)
            ys_true = ys_out_pad
            for (i, y_hat), y_true in zip(
                enumerate(ys_hat.detach().cpu().numpy()),
                ys_true.detach().cpu().numpy(),
            ):
                if i == MAX_DECODER_OUTPUT:
                    break
                idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
                idx_true = y_true[y_true != self.ignore_id]
                seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
                seq_true = [self.char_list[int(idx)] for idx in idx_true]
                seq_hat = "".join(seq_hat)
                seq_true = "".join(seq_true)
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)

        if self.labeldist is not None:
            if self.vlabeldist is None:
                self.vlabeldist = to_device(
                    hs_pad[0], torch.from_numpy(self.labeldist)
                )
            loss_reg = -torch.sum(
                (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0
            ) / len(ys_in)
            self.loss = (1.0 - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc, ppl

    def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
        """Beam search implementation.

        :param torch.Tensor h: encoder hidden state (T, eprojs)
            [in multi-encoder case, list of torch.Tensor,
            [(T1, eprojs), (T2, eprojs), ...] ]
        :param torch.Tensor lpz: ctc log softmax output (T, odim)
            [in multi-encoder case, list of torch.Tensor,
            [(T1, odim), (T2, odim), ...] ]
        :param Namespace recog_args: argument Namespace containing options
        :param char_list: list of character strings
        :param torch.nn.Module rnnlm: language model module
        :param int strm_idx:
            stream index for speaker parallel attention in multi-speaker case
        :return: N-best decoding results
        :rtype: list of dicts
        """
        # to support multiple-encoder asr mode; in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            h = [h]
            lpz = [lpz]
        if self.num_encs > 1 and lpz is None:
            lpz = [lpz] * self.num_encs

        for idx in range(self.num_encs):
            logging.info(
                "Number of Encoders: {}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, h[idx].size(0)
                )
            )
        att_idx = min(strm_idx, len(self.att) - 1)
        # initialization
        c_list = [self.zero_state(h[0].unsqueeze(0))]
        z_list = [self.zero_state(h[0].unsqueeze(0))]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(h[0].unsqueeze(0)))
            z_list.append(self.zero_state(h[0].unsqueeze(0)))
        if self.num_encs == 1:
            a = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = getattr(recog_args, "ctc_weight", False)  # for NMT

        if lpz[0] is not None and self.num_encs > 1:
            # weights-ctc,
            # e.g. ctc_loss = w_1 * ctc_1_loss + w_2 * ctc_2_loss + ... + w_N * ctc_N_loss
            weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
                recog_args.weights_ctc_dec
            )  # normalize
            logging.info(
                "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
            )
        else:
            weights_ctc_dec = [1.0]

        # prepare sos
        if self.replace_sos and recog_args.tgt_lang:
            y = char_list.index(recog_args.tgt_lang)
        else:
            y = self.sos
        logging.info("<sos> index: " + str(y))
        logging.info("<sos> mark: " + char_list[y])
        vy = h[0].new_zeros(1).long()

        maxlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)])
        if recog_args.maxlenratio != 0:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * maxlen))
        minlen = int(recog_args.minlenratio * maxlen)
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialize hypothesis
        hyp = {
            "score": 0.0,
            "yseq": [y],
            "c_prev": c_list,
            "z_prev": z_list,
            "a_prev": a,
        }
        if rnnlm:
            hyp["rnnlm_prev"] = None
        if lpz[0] is not None:
            ctc_prefix_score = [
                CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np)
                for idx in range(self.num_encs)
            ]
            hyp["ctc_state_prev"] = [
                ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs)
            ]
            hyp["ctc_score_prev"] = [0.0] * self.num_encs
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz[0].shape[-1]
        hyps = [hyp]
        ended_hyps = []

        for i in six.moves.range(maxlen):
            logging.debug("position " + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp["yseq"][i]
                ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
                if self.num_encs == 1:
                    att_c, att_w = self.att[att_idx](
                        h[0].unsqueeze(0),
                        [h[0].size(0)],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"],
                    )
                else:
                    for idx in range(self.num_encs):
                        att_c_list[idx], att_w_list[idx] = self.att[idx](
                            h[idx].unsqueeze(0),
                            [h[idx].size(0)],
                            self.dropout_dec[0](hyp["z_prev"][0]),
                            hyp["a_prev"][idx],
                        )
                    h_han = torch.stack(att_c_list, dim=1)
                    att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                        h_han,
                        [self.num_encs],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"][self.num_encs],
                    )
                ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
                z_list, c_list = self.rnn_forward(
                    ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
                )
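                # joint scoring: combine (1 - ctc_weight) * attention log-probs
                # with ctc_weight * CTC prefix-score deltas (plus, optionally,
                # lm_weight * RNNLM scores); CTC is evaluated only on the top
                # `ctc_beam` attention candidates to keep prefix scoring cheap.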
                # get nbest local scores and their ids
                if self.context_residual:
                    logits = self.output(
                        torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                    )
                else:
                    logits = self.output(self.dropout_dec[-1](z_list[-1]))
                local_att_scores = F.log_softmax(logits, dim=1)
                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                    local_scores = (
                        local_att_scores + recog_args.lm_weight * local_lm_scores
                    )
                else:
                    local_scores = local_att_scores

                if lpz[0] is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1
                    )
                    ctc_scores, ctc_states = (
                        [None] * self.num_encs,
                        [None] * self.num_encs,
                    )
                    for idx in range(self.num_encs):
                        ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
                            hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
                        )
                    local_scores = (1.0 - ctc_weight) * local_att_scores[
                        :, local_best_ids[0]
                    ]
                    if self.num_encs == 1:
                        local_scores += ctc_weight * torch.from_numpy(
                            ctc_scores[0] - hyp["ctc_score_prev"][0]
                        )
                    else:
                        for idx in range(self.num_encs):
                            local_scores += (
                                ctc_weight
                                * weights_ctc_dec[idx]
                                * torch.from_numpy(
                                    ctc_scores[idx] - hyp["ctc_score_prev"][idx]
                                )
                            )
                    if rnnlm:
                        local_scores += (
                            recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                        )
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )

                for j in six.moves.range(beam):
                    new_hyp = {}
                    # [:] is needed to copy the state lists, not alias them
                    new_hyp["z_prev"] = z_list[:]
                    new_hyp["c_prev"] = c_list[:]
                    if self.num_encs == 1:
                        new_hyp["a_prev"] = att_w[:]
                    else:
                        new_hyp["a_prev"] = [
                            att_w_list[idx][:] for idx in range(self.num_encs + 1)
                        ]
                    new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                    new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                    new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp["rnnlm_prev"] = rnnlm_state
                    if lpz[0] is not None:
                        new_hyp["ctc_state_prev"] = [
                            ctc_states[idx][joint_best_ids[0, j]]
                            for idx in range(self.num_encs)
                        ]
                        new_hyp["ctc_score_prev"] = [
                            ctc_scores[idx][joint_best_ids[0, j]]
                            for idx in range(self.num_encs)
                        ]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x["score"], reverse=True
                )[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug("number of pruned hypotheses: " + str(len(hyps)))
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
            )

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info("adding <eos> in the last position in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)

            # add ended hypotheses to the final list,
            # and remove them from the current hypotheses
            # (this can be a problem: the number of hyps may drop below beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    # only store a sequence that has more than minlen outputs
                    # and also add the length penalty
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                hyp["rnnlm_prev"]
                            )
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remaining hypotheses: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break

            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                )

            logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps), recog_args.nbest)
        ]

        # check the number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning(
                "there are no N-best results; "
                "perform recognition again with a smaller minlenratio."
            )
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            if self.num_encs == 1:
                return self.recognize_beam(h[0], lpz[0], recog_args, char_list, rnnlm)
            else:
                return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info(
            "normalized log probability: "
            + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
        )

        return nbest_hyps

    def recognize_beam_batch(
        self,
        h,
        hlens,
        lpz,
        recog_args,
        char_list,
        rnnlm=None,
        normalize_score=True,
        strm_idx=0,
        lang_ids=None,
    ):
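        """Batch beam search (batched counterpart of recognize_beam).

        :param torch.Tensor h: batch of padded encoder hidden states
            (B, Tmax, eprojs) [in multi-encoder case, list of torch.Tensor]
        :param torch.Tensor hlens: batch of encoder output lengths (B)
            [in multi-encoder case, list of torch.Tensor]
        :param torch.Tensor lpz: batch of ctc log softmax outputs (B, Tmax, odim)
            [in multi-encoder case, list of torch.Tensor]
        :param Namespace recog_args: argument Namespace containing options
        :param char_list: list of character strings
        :param torch.nn.Module rnnlm: language model module
        :param bool normalize_score: normalize scores by hypothesis length
        :param int strm_idx: stream index for speaker parallel attention
        :param torch.Tensor lang_ids: batch of target language ids (B, 1)
        :return: N-best decoding results for each utterance
        :rtype: list of lists of dicts
        """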
        # to support multiple-encoder asr mode; in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            h = [h]
            hlens = [hlens]
            lpz = [lpz]
        if self.num_encs > 1 and lpz is None:
            lpz = [lpz] * self.num_encs

        att_idx = min(strm_idx, len(self.att) - 1)
        for idx in range(self.num_encs):
            logging.info(
                "Number of Encoders: {}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, h[idx].size(1)
                )
            )
            h[idx] = mask_by_length(h[idx], hlens[idx], 0.0)

        # search params
        batch = len(hlens[0])
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = getattr(recog_args, "ctc_weight", 0)  # for NMT
        att_weight = 1.0 - ctc_weight
        ctc_margin = getattr(
            recog_args, "ctc_window_margin", 0
        )  # use getattr to keep compatibility
        # weights-ctc,
        # e.g. ctc_loss = w_1 * ctc_1_loss + w_2 * ctc_2_loss + ... + w_N * ctc_N_loss
        if lpz[0] is not None and self.num_encs > 1:
            weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
                recog_args.weights_ctc_dec
            )  # normalize
            logging.info(
                "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
            )
        else:
            weights_ctc_dec = [1.0]
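        # the batch and beam dimensions are flattened into one dimension of
        # size n_bb = batch * beam; pad_b holds the offset of each utterance's
        # beam block, so flat index = beam_id + pad_b[utt]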
        n_bb = batch * beam
        pad_b = to_device(h[0], torch.arange(batch) * beam).view(-1, 1)

        max_hlen = np.amin([max(hlens[idx]) for idx in range(self.num_encs)])
        if recog_args.maxlenratio == 0:
            maxlen = max_hlen
        else:
            maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
        minlen = int(recog_args.minlenratio * max_hlen)
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialization
        c_prev = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        z_prev = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        c_list = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        z_list = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        vscores = to_device(h[0], torch.zeros(batch, beam))

        rnnlm_state = None
        if self.num_encs == 1:
            a_prev = [None]
            att_w_list, ctc_scorer, ctc_state = [None], [None], [None]
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            a_prev = [None] * (self.num_encs + 1)  # atts + han
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            ctc_scorer, ctc_state = [None] * (self.num_encs), [None] * (self.num_encs)
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        if self.replace_sos and recog_args.tgt_lang:
            logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang)))
            logging.info("<sos> mark: " + recog_args.tgt_lang)
            yseq = [
                [char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)
            ]
        elif lang_ids is not None:
            # NOTE: used for evaluation during training
            yseq = [
                [lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)
            ]
        else:
            logging.info("<sos> index: " + str(self.sos))
            logging.info("<sos> mark: " + char_list[self.sos])
            yseq = [[self.sos] for _ in six.moves.range(n_bb)]

        accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
        stop_search = [False for _ in six.moves.range(batch)]
        nbest_hyps = [[] for _ in six.moves.range(batch)]
        ended_hyps = [[] for _ in range(batch)]

        exp_hlens = [
            hlens[idx].repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
            for idx in range(self.num_encs)
        ]
        exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
        exp_h = [
            h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
            for idx in range(self.num_encs)
        ]
        exp_h = [
            exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2])
            for idx in range(self.num_encs)
        ]

        if lpz[0] is not None:
            scoring_num = min(
                (
                    int(beam * CTC_SCORING_RATIO)
                    if att_weight > 0.0 and not lpz[0].is_cuda
                    else 0
                ),
                lpz[0].size(-1),
            )
            ctc_scorer = [
                CTCPrefixScoreTH(
                    lpz[idx],
                    hlens[idx],
                    0,
                    self.eos,
                    margin=ctc_margin,
                )
                for idx in range(self.num_encs)
            ]
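            # NOTE: CTCPrefixScoreTH scores all batch * beam hypotheses at
            # once; a non-zero ctc_window_margin presumably restricts scoring
            # to a window around the attention peak to save computation.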

        for i in six.moves.range(maxlen):
            logging.debug("position " + str(i))

            vy = to_device(h[0], torch.LongTensor(self._get_last_yseq(yseq)))
            ey = self.dropout_emb(self.embed(vy))
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    exp_h[0], exp_hlens[0], self.dropout_dec[0](z_prev[0]), a_prev[0]
                )
                att_w_list = [att_w]
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        exp_h[idx],
                        exp_hlens[idx],
                        self.dropout_dec[0](z_prev[0]),
                        a_prev[idx],
                    )
                exp_h_han = torch.stack(att_c_list, dim=1)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    exp_h_han,
                    [self.num_encs] * n_bb,
                    self.dropout_dec[0](z_prev[0]),
                    a_prev[self.num_encs],
                )
            ey = torch.cat((ey, att_c), dim=1)

            # attention decoder
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
            if self.context_residual:
                logits = self.output(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )
            else:
                logits = self.output(self.dropout_dec[-1](z_list[-1]))
            local_scores = att_weight * F.log_softmax(logits, dim=1)

            # rnnlm
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
                local_scores = local_scores + recog_args.lm_weight * local_lm_scores

            # ctc
            if ctc_scorer[0]:
                local_scores[:, 0] = self.logzero  # avoid choosing blank
                part_ids = (
                    torch.topk(local_scores, scoring_num, dim=-1)[1]
                    if scoring_num > 0
                    else None
                )
                for idx in range(self.num_encs):
                    att_w = att_w_list[idx]
                    att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
                    local_ctc_scores, ctc_state[idx] = ctc_scorer[idx](
                        yseq, ctc_state[idx], part_ids, att_w_
                    )
                    local_scores = (
                        local_scores
                        + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
                    )

            local_scores = local_scores.view(batch, beam, self.odim)
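            # at step 0 every beam slot holds the same <sos> hypothesis, so all
            # but the first slot are masked to logzero to avoid duplicates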
            if i == 0:
                local_scores[:, 1:, :] = self.logzero

            # accumulate scores: track <eos>-ending scores separately and mask
            # <eos> in vscores so surviving hypotheses never contain it
            eos_vscores = local_scores[:, :, self.eos] + vscores
            vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
            vscores[:, :, self.eos] = self.logzero
            vscores = (vscores + local_scores).view(batch, -1)

            # global pruning
            accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
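            # accum_best_ids indexes the flattened (beam * odim) score matrix:
            # fmod(id, odim) recovers the token id, and id // odim (offset by
            # pad_b) recovers the source beam slot in the flattened n_bb layout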
            accum_odim_ids = (
                torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
            )
            accum_padded_beam_ids = (
                (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()
            )

            y_prev = yseq[:][:]
            yseq = self._index_select_list(yseq, accum_padded_beam_ids)
            yseq = self._append_ids(yseq, accum_odim_ids)
            vscores = accum_best_scores
            vidx = to_device(h[0], torch.LongTensor(accum_padded_beam_ids))

            a_prev = []
            num_atts = self.num_encs if self.num_encs == 1 else self.num_encs + 1
            for idx in range(num_atts):
                if isinstance(att_w_list[idx], torch.Tensor):
                    _a_prev = torch.index_select(
                        att_w_list[idx].view(n_bb, *att_w_list[idx].shape[1:]), 0, vidx
                    )
                elif isinstance(att_w_list[idx], list):
                    # handle the case of multi-head attention
                    _a_prev = [
                        torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
                        for att_w_one in att_w_list[idx]
                    ]
                else:
                    # handle the case of location_recurrent when return is a tuple
                    _a_prev_ = torch.index_select(
                        att_w_list[idx][0].view(n_bb, -1), 0, vidx
                    )
                    _h_prev_ = torch.index_select(
                        att_w_list[idx][1][0].view(n_bb, -1), 0, vidx
                    )
                    _c_prev_ = torch.index_select(
                        att_w_list[idx][1][1].view(n_bb, -1), 0, vidx
                    )
                    _a_prev = (_a_prev_, (_h_prev_, _c_prev_))
                a_prev.append(_a_prev)
            z_prev = [
                torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
                for li in range(self.dlayers)
            ]
            c_prev = [
                torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
                for li in range(self.dlayers)
            ]

            # pick ended hyps
            if i >= minlen:
                k = 0
                penalty_i = (i + 1) * penalty
                thr = accum_best_scores[:, -1]
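                # a hypothesis is finalized only when its <eos> score beats
                # thr, the score of the worst hypothesis still kept in the beam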
                for samp_i in six.moves.range(batch):
                    if stop_search[samp_i]:
                        k = k + beam
                        continue
                    for beam_j in six.moves.range(beam):
                        _vscore = None
                        if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                            yk = y_prev[k][:]
                            if len(yk) <= min(
                                hlens[idx][samp_i] for idx in range(self.num_encs)
                            ):
                                _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                        elif i == maxlen - 1:
                            yk = yseq[k][:]
                            _vscore = vscores[samp_i][beam_j] + penalty_i
                        if _vscore:
                            yk.append(self.eos)
                            if rnnlm:
                                _vscore += recog_args.lm_weight * rnnlm.final(
                                    rnnlm_state, index=k
                                )
                            _score = _vscore.data.cpu().numpy()
                            ended_hyps[samp_i].append(
                                {"yseq": yk, "vscore": _vscore, "score": _score}
                            )
                        k = k + 1

            # end detection
            stop_search = [
                stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
                for samp_i in six.moves.range(batch)
            ]
            stop_search_summary = list(set(stop_search))
            if len(stop_search_summary) == 1 and stop_search_summary[0]:
                break

            if rnnlm:
                rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
            if ctc_scorer[0]:
                for idx in range(self.num_encs):
                    ctc_state[idx] = ctc_scorer[idx].index_select_state(
                        ctc_state[idx], accum_best_ids
                    )

        torch.cuda.empty_cache()

        dummy_hyps = [
            {"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}
        ]
        ended_hyps = [
            ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
            for samp_i in six.moves.range(batch)
        ]
        if normalize_score:
            for samp_i in six.moves.range(batch):
                for x in ended_hyps[samp_i]:
                    x["score"] /= len(x["yseq"])

        nbest_hyps = [
            sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[
                : min(len(ended_hyps[samp_i]), recog_args.nbest)
            ]
            for samp_i in six.moves.range(batch)
        ]

        return nbest_hyps

    def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None):
        """Calculate all of the attentions.

        :param torch.Tensor hs_pad: batch of padded hidden state sequences
            (B, Tmax, D)
            [in multi-encoder case, list of torch.Tensor,
            [(B, Tmax_1, D), (B, Tmax_2, D), ...] ]
        :param torch.Tensor hlen: batch of lengths of hidden state sequences (B)
            [in multi-encoder case, list of torch.Tensor,
            [(B), (B), ...] ]
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, Lmax)
        :param int strm_idx:
            stream index for parallel speaker attention in multi-speaker case
        :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) multi-encoder case =>
                [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)],
            3) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        # to support multiple-encoder asr mode; in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlen = [hlen]

        # TODO(kan-bayashi): need to find a smarter way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys

        att_idx = min(strm_idx, len(self.att) - 1)

        # hlen should be a list of lists of integers
        hlen = [list(map(int, hlen[idx])) for idx in range(self.num_encs)]

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # padding for ys with -1
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)

        # get length info
        olength = ys_out_pad.size(1)

        # initialization
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        att_ws = []
        if self.num_encs == 1:
            att_w = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in six.moves.range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    hs_pad[0], hlen[0], self.dropout_dec[0](z_list[0]), att_w
                )
                att_ws.append(att_w)
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        hs_pad[idx],
                        hlen[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlen_han = [self.num_encs] * len(ys_in)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    hs_pad_han,
                    hlen_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
                att_ws.append(att_w_list.copy())
            ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)

        if self.num_encs == 1:
            # convert to a numpy array with the shape (B, Lmax, Tmax)
            att_ws = att_to_numpy(att_ws, self.att[att_idx])
        else:
            _att_ws = []
            for idx, ws in enumerate(zip(*att_ws)):
                ws = att_to_numpy(ws, self.att[idx])
                _att_ws.append(ws)
            att_ws = _att_ws
        return att_ws

    @staticmethod
    def _get_last_yseq(exp_yseq):
        last = []
        for y_seq in exp_yseq:
            last.append(y_seq[-1])
        return last

    @staticmethod
    def _append_ids(yseq, ids):
        if isinstance(ids, list):
            for i, j in enumerate(ids):
                yseq[i].append(j)
        else:
            for i in range(len(yseq)):
                yseq[i].append(ids)
        return yseq
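    # e.g. _append_ids([[1, 2], [3, 4]], [5, 6]) -> [[1, 2, 5], [3, 4, 6]],
    # while a scalar id is appended to every sequence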

    @staticmethod
    def _index_select_list(yseq, lst):
        new_yseq = []
        for i in lst:
            new_yseq.append(yseq[i][:])
        return new_yseq

    @staticmethod
    def _index_select_lm_state(rnnlm_state, dim, vidx):
        if isinstance(rnnlm_state, dict):
            new_state = {}
            for k, v in rnnlm_state.items():
                new_state[k] = [torch.index_select(vi, dim, vidx) for vi in v]
        elif isinstance(rnnlm_state, list):
            new_state = []
            for i in vidx:
                new_state.append(rnnlm_state[int(i)][:])
        return new_state

    # scorer interface methods
    def init_state(self, x):
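        """Initialize the decoder state for scoring (ScorerInterface).

        :param torch.Tensor x: encoder output (T, eprojs)
        :return: initial state dict holding previous cell/hidden states,
            attention weights, and a workspace tuple
        :rtype: dict
        """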
        # to support multiple-encoder asr mode; in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        c_list = [self.zero_state(x[0].unsqueeze(0))]
        z_list = [self.zero_state(x[0].unsqueeze(0))]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(x[0].unsqueeze(0)))
            z_list.append(self.zero_state(x[0].unsqueeze(0)))
        # TODO(karita): support strm_index for `asr_mix`
        strm_index = 0
        att_idx = min(strm_index, len(self.att) - 1)
        if self.num_encs == 1:
            a = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han
        return dict(
            c_prev=c_list[:],
            z_prev=z_list[:],
            a_prev=a,
            workspace=(att_idx, z_list, c_list),
        )

    def score(self, yseq, state, x):
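        """Score a new token given the prefix (ScorerInterface).

        :param torch.Tensor yseq: prefix token id sequence
        :param dict state: decoder state returned by init_state() or score()
        :param torch.Tensor x: encoder output (T, eprojs)
        :return: log-probabilities over the vocabulary and the updated state
        """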
        # to support multiple-encoder asr mode; in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        att_idx, z_list, c_list = state["workspace"]
        vy = yseq[-1].unsqueeze(0)
        ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                x[0].unsqueeze(0),
                [x[0].size(0)],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"],
            )
        else:
            att_w = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs):
                att_c_list[idx], att_w[idx] = self.att[idx](
                    x[idx].unsqueeze(0),
                    [x[idx].size(0)],
                    self.dropout_dec[0](state["z_prev"][0]),
                    state["a_prev"][idx],
                )
            h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w[self.num_encs] = self.att[self.num_encs](
                h_han,
                [self.num_encs],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"][self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(
            ey, z_list, c_list, state["z_prev"], state["c_prev"]
        )
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
        return (
            logp,
            dict(
                c_prev=c_list[:],
                z_prev=z_list[:],
                a_prev=att_w,
                workspace=(att_idx, z_list, c_list),
            ),
        )


def decoder_for(args, odim, sos, eos, att, labeldist):
    return Decoder(
        args.eprojs,
        odim,
        args.dtype,
        args.dlayers,
        args.dunits,
        sos,
        eos,
        att,
        args.verbose,
        args.char_list,
        labeldist,
        args.lsm_weight,
        args.sampling_probability,
        args.dropout_rate_decoder,
        getattr(args, "context_residual", False),  # use getattr to keep compatibility
        getattr(args, "replace_sos", False),  # use getattr to keep compatibility
        getattr(args, "num_encs", 1),  # use getattr to keep compatibility
    )
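

# A minimal construction sketch (hypothetical argument values; `att` must be
# indexable, e.g. a torch.nn.ModuleList of attention modules, since the
# decoder selects from it with att_idx):
#   from argparse import Namespace
#   args = Namespace(eprojs=320, dtype="lstm", dlayers=1, dunits=300,
#                    verbose=0, char_list=char_list, lsm_weight=0.1,
#                    sampling_probability=0.0, dropout_rate_decoder=0.0)
#   decoder = decoder_for(args, odim=len(char_list), sos=len(char_list) - 1,
#                         eos=len(char_list) - 1, att=att, labeldist=None)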