| """Attention modules for RNN.""" | |
| import math | |
| import six | |
| import torch | |
| import torch.nn.functional as F | |
| from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask | |
| from funasr_detach.models.transformer.utils.nets_utils import to_device | |
| def _apply_attention_constraint( | |
| e, last_attended_idx, backward_window=1, forward_window=3 | |
| ): | |
| """Apply monotonic attention constraint. | |
| This function applies the monotonic attention constraint | |
| introduced in `Deep Voice 3: Scaling | |
| Text-to-Speech with Convolutional Sequence Learning`_. | |
| Args: | |
| e (Tensor): Attention energy before applying softmax (1, T). | |
| last_attended_idx (int): Index of the last attended input, in [0, T]. | |
| backward_window (int, optional): Backward window size in attention constraint. | |
| forward_window (int, optional): Forward window size in attention constraint. | |
| Returns: | |
| Tensor: Monotonic constrained attention energy (1, T). | |
| .. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`: | |
| https://arxiv.org/abs/1710.07654 | |
| """ | |
| if e.size(0) != 1: | |
| raise NotImplementedError("Batch attention constraining is not yet supported.") | |
| backward_idx = last_attended_idx - backward_window | |
| forward_idx = last_attended_idx + forward_window | |
| if backward_idx > 0: | |
| e[:, :backward_idx] = -float("inf") | |
| if forward_idx < e.size(1): | |
| e[:, forward_idx:] = -float("inf") | |
| return e | |
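# Illustrative sketch (not part of the original module): with hypothetical
# values last_attended_idx=5, backward_window=1, forward_window=3, only the
# energies for indices [4, 8) survive; everything else becomes -inf, so the
# subsequent softmax can only move attention inside that window.
def _example_attention_constraint():
    e = torch.zeros(1, 10)  # (1, T) attention energies for a single utterance
    e = _apply_attention_constraint(e, last_attended_idx=5)
    # e[:, :4] and e[:, 8:] are now -inf; softmax(e) is non-zero only on [4, 8)
    return e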
| class NoAtt(torch.nn.Module): | |
| """No attention""" | |
| def __init__(self): | |
| super(NoAtt, self).__init__() | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.c = None | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.c = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): | |
| """NoAtt forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: dummy (not used) | |
| :param torch.Tensor att_prev: dummy (not used) | |
| :return: attention weighted encoder state (B, D_enc) | |
| :rtype: torch.Tensor | |
| :return: previous attention weights | |
| :rtype: torch.Tensor | |
| """ | |
| batch = len(enc_hs_pad) | |
| # pre-compute all h outside the decoder loop | |
| if self.pre_compute_enc_h is None: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # initialize attention weight with uniform dist. | |
| if att_prev is None: | |
| # if no bias, 0-padded frames stay 0 | |
| mask = 1.0 - make_pad_mask(enc_hs_len).float() | |
| att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1) | |
| att_prev = att_prev.to(self.enc_h) | |
| self.c = torch.sum( | |
| self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1 | |
| ) | |
| return self.c, att_prev | |
| class AttDot(torch.nn.Module): | |
| """Dot product attention | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int att_dim: attention dimension | |
| :param bool han_mode: flag to switch on hierarchical attention mode | |
| and not store pre_compute_enc_h | |
| """ | |
| def __init__(self, eprojs, dunits, att_dim, han_mode=False): | |
| super(AttDot, self).__init__() | |
| self.mlp_enc = torch.nn.Linear(eprojs, att_dim) | |
| self.mlp_dec = torch.nn.Linear(dunits, att_dim) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.att_dim = att_dim | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): | |
| """AttDot forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: dummy (not used) | |
| :param torch.Tensor att_prev: dummy (not used) | |
| :param float scaling: scaling parameter before applying softmax | |
| :return: attention weighted encoder state (B, D_enc) | |
| :rtype: torch.Tensor | |
| :return: previous attention weight (B x T_max) | |
| :rtype: torch.Tensor | |
| """ | |
| batch = enc_hs_pad.size(0) | |
| # pre-compute all h outside the decoder loop | |
| if self.pre_compute_enc_h is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_enc_h = torch.tanh(self.mlp_enc(self.enc_h)) | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| e = torch.sum( | |
| self.pre_compute_enc_h | |
| * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim), | |
| dim=2, | |
| ) # utt x frame | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| w = F.softmax(scaling * e, dim=1) | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) | |
| return c, w | |
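# Illustrative sketch (toy sizes, not part of the original module): AttDot is
# called once per decoder step after reset(); it caches the projected encoder
# states on the first call and returns the context c (B, eprojs) together with
# the attention weights w (B, T_max).
def _example_att_dot():
    att = AttDot(eprojs=4, dunits=3, att_dim=5)
    enc_hs_pad = torch.randn(2, 6, 4)  # padded encoder states (B, T_max, eprojs)
    enc_hs_len = [6, 4]                # valid length of each utterance
    dec_z = torch.zeros(2, 3)          # decoder hidden state (B, dunits)
    att.reset()
    c, w = att(enc_hs_pad, enc_hs_len, dec_z, None)
    return c, w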
| class AttAdd(torch.nn.Module): | |
| """Additive attention | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int att_dim: attention dimension | |
| :param bool han_mode: flag to switch on hierarchical attention mode | |
| and not store pre_compute_enc_h | |
| """ | |
| def __init__(self, eprojs, dunits, att_dim, han_mode=False): | |
| super(AttAdd, self).__init__() | |
| self.mlp_enc = torch.nn.Linear(eprojs, att_dim) | |
| self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) | |
| self.gvec = torch.nn.Linear(att_dim, 1) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.att_dim = att_dim | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): | |
| """AttAdd forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param torch.Tensor att_prev: dummy (not used) | |
| :param float scaling: scaling parameter before applying softmax | |
| :return: attention weighted encoder state (B, D_enc) | |
| :rtype: torch.Tensor | |
| :return: previous attention weights (B x T_max) | |
| :rtype: torch.Tensor | |
| """ | |
| batch = len(enc_hs_pad) | |
| # pre-compute all h outside the decoder loop | |
| if self.pre_compute_enc_h is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_enc_h = self.mlp_enc(self.enc_h) | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| # dec_z_tiled: utt x frame x att_dim | |
| dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) | |
| # dot with gvec | |
| # utt x frame x att_dim -> utt x frame | |
| e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled)).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| w = F.softmax(scaling * e, dim=1) | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) | |
| return c, w | |
| class AttLoc(torch.nn.Module): | |
| """location-aware attention module. | |
| Reference: Attention-Based Models for Speech Recognition | |
| (https://arxiv.org/pdf/1506.07503.pdf) | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int att_dim: attention dimension | |
| :param int aconv_chans: # channels of attention convolution | |
| :param int aconv_filts: filter size of attention convolution | |
| :param bool han_mode: flag to switch on hierarchical attention mode | |
| and not store pre_compute_enc_h | |
| """ | |
| def __init__( | |
| self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False | |
| ): | |
| super(AttLoc, self).__init__() | |
| self.mlp_enc = torch.nn.Linear(eprojs, att_dim) | |
| self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) | |
| self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) | |
| self.loc_conv = torch.nn.Conv2d( | |
| 1, | |
| aconv_chans, | |
| (1, 2 * aconv_filts + 1), | |
| padding=(0, aconv_filts), | |
| bias=False, | |
| ) | |
| self.gvec = torch.nn.Linear(att_dim, 1) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.att_dim = att_dim | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| def forward( | |
| self, | |
| enc_hs_pad, | |
| enc_hs_len, | |
| dec_z, | |
| att_prev, | |
| scaling=2.0, | |
| last_attended_idx=None, | |
| backward_window=1, | |
| forward_window=3, | |
| ): | |
| """Calculate AttLoc forward propagation. | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param torch.Tensor att_prev: previous attention weight (B x T_max) | |
| :param float scaling: scaling parameter before applying softmax | |
| :param int last_attended_idx: index of the last attended input | |
| :param int backward_window: backward window size in attention constraint | |
| :param int forward_window: forward window size in attention constraint | |
| :return: attention weighted encoder state (B, D_enc) | |
| :rtype: torch.Tensor | |
| :return: previous attention weights (B x T_max) | |
| :rtype: torch.Tensor | |
| """ | |
| batch = len(enc_hs_pad) | |
| # pre-compute all h outside the decoder loop | |
| if self.pre_compute_enc_h is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_enc_h = self.mlp_enc(self.enc_h) | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| # initialize attention weight with uniform dist. | |
| if att_prev is None: | |
| # if no bias, 0-padded frames stay 0 | |
| att_prev = 1.0 - make_pad_mask(enc_hs_len).to( | |
| device=dec_z.device, dtype=dec_z.dtype | |
| ) | |
| att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) | |
| # att_prev: utt x frame -> utt x 1 x 1 x frame | |
| # -> utt x att_conv_chans x 1 x frame | |
| att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) | |
| # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans | |
| att_conv = att_conv.squeeze(2).transpose(1, 2) | |
| # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim | |
| att_conv = self.mlp_att(att_conv) | |
| # dec_z_tiled: utt x frame x att_dim | |
| dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) | |
| # dot with gvec | |
| # utt x frame x att_dim -> utt x frame | |
| e = self.gvec( | |
| torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) | |
| ).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| # apply monotonic attention constraint (mainly for TTS) | |
| if last_attended_idx is not None: | |
| e = _apply_attention_constraint( | |
| e, last_attended_idx, backward_window, forward_window | |
| ) | |
| w = F.softmax(scaling * e, dim=1) | |
| # weighted sum over frames | |
| # utt x hdim | |
| c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) | |
| return c, w | |
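# Illustrative decoder-loop sketch (toy sizes, not part of the original
# module): the weights returned at step t are fed back as att_prev at step
# t + 1 so the location convolution can see where the model attended before.
def _example_att_loc():
    att = AttLoc(eprojs=4, dunits=3, att_dim=5, aconv_chans=2, aconv_filts=3)
    enc_hs_pad = torch.randn(2, 7, 4)  # (B, T_max, eprojs)
    enc_hs_len = [7, 5]
    dec_z = torch.zeros(2, 3)          # (B, dunits)
    att.reset()
    w = None
    for _ in range(3):                 # a few decoder steps
        c, w = att(enc_hs_pad, enc_hs_len, dec_z, w)
    return c, w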
| class AttCov(torch.nn.Module): | |
| """Coverage mechanism attention | |
| Reference: Get To The Point: Summarization with Pointer-Generator Networks | |
| (https://arxiv.org/abs/1704.04368) | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int att_dim: attention dimension | |
| :param bool han_mode: flag to switch on hierarchical attention mode | |
| and not store pre_compute_enc_h | |
| """ | |
| def __init__(self, eprojs, dunits, att_dim, han_mode=False): | |
| super(AttCov, self).__init__() | |
| self.mlp_enc = torch.nn.Linear(eprojs, att_dim) | |
| self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) | |
| self.wvec = torch.nn.Linear(1, att_dim) | |
| self.gvec = torch.nn.Linear(att_dim, 1) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.att_dim = att_dim | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0): | |
| """AttCov forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param list att_prev_list: list of previous attention weight | |
| :param float scaling: scaling parameter before applying softmax | |
| :return: attention weighted encoder state (B, D_enc) | |
| :rtype: torch.Tensor | |
| :return: list of previous attention weights | |
| :rtype: list | |
| """ | |
| batch = len(enc_hs_pad) | |
| # pre-compute all h outside the decoder loop | |
| if self.pre_compute_enc_h is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_enc_h = self.mlp_enc(self.enc_h) | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| # initialize attention weight with uniform dist. | |
| if att_prev_list is None: | |
| # if no bias, 0-padded frames stay 0 | |
| att_prev_list = to_device( | |
| enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()) | |
| ) | |
| att_prev_list = [ | |
| att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1) | |
| ] | |
| # att_prev_list: L' * [B x T] => cov_vec B x T | |
| cov_vec = sum(att_prev_list) | |
| # cov_vec: B x T => B x T x 1 => B x T x att_dim | |
| cov_vec = self.wvec(cov_vec.unsqueeze(-1)) | |
| # dec_z_tiled: utt x frame x att_dim | |
| dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) | |
| # dot with gvec | |
| # utt x frame x att_dim -> utt x frame | |
| e = self.gvec( | |
| torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled) | |
| ).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| w = F.softmax(scaling * e, dim=1) | |
| att_prev_list += [w] | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) | |
| return c, att_prev_list | |
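# Illustrative sketch (toy sizes, not part of the original module): AttCov
# returns a growing list of weights; the list fed back at each step is summed
# into the coverage vector, and the new weights are appended to it.
def _example_att_cov():
    att = AttCov(eprojs=4, dunits=3, att_dim=5)
    enc_hs_pad = torch.randn(2, 6, 4)
    enc_hs_len = [6, 4]
    dec_z = torch.zeros(2, 3)
    att.reset()
    att_prev_list = None
    for _ in range(3):
        c, att_prev_list = att(enc_hs_pad, enc_hs_len, dec_z, att_prev_list)
    # att_prev_list now holds the initial uniform weights plus one (B, T_max)
    # tensor per decoder step
    return c, att_prev_list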
| class AttLoc2D(torch.nn.Module): | |
| """2D location-aware attention | |
| This attention is an extended version of location aware attention. | |
| It takes not only the attention weights of the immediately preceding step | |
| but also those of earlier steps into account. | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int att_dim: attention dimension | |
| :param int aconv_chans: # channels of attention convolution | |
| :param int aconv_filts: filter size of attention convolution | |
| :param int att_win: attention window size (default=5) | |
| :param bool han_mode: | |
| flag to switch on hierarchical attention mode and not store pre_compute_enc_h | |
| """ | |
| def __init__( | |
| self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False | |
| ): | |
| super(AttLoc2D, self).__init__() | |
| self.mlp_enc = torch.nn.Linear(eprojs, att_dim) | |
| self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) | |
| self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) | |
| self.loc_conv = torch.nn.Conv2d( | |
| 1, | |
| aconv_chans, | |
| (att_win, 2 * aconv_filts + 1), | |
| padding=(0, aconv_filts), | |
| bias=False, | |
| ) | |
| self.gvec = torch.nn.Linear(att_dim, 1) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.att_dim = att_dim | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.aconv_chans = aconv_chans | |
| self.att_win = att_win | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): | |
| """AttLoc2D forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param torch.Tensor att_prev: previous attention weight (B x att_win x T_max) | |
| :param float scaling: scaling parameter before applying softmax | |
| :return: attention weighted encoder state (B, D_enc) | |
| :rtype: torch.Tensor | |
| :return: previous attention weights (B x att_win x T_max) | |
| :rtype: torch.Tensor | |
| """ | |
| batch = len(enc_hs_pad) | |
| # pre-compute all h outside the decoder loop | |
| if self.pre_compute_enc_h is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_enc_h = self.mlp_enc(self.enc_h) | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| # initialize attention weight with uniform dist. | |
| if att_prev is None: | |
| # B * [Li x att_win] | |
| # if no bias, 0-padded frames stay 0 | |
| att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())) | |
| att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) | |
| att_prev = att_prev.unsqueeze(1).expand(-1, self.att_win, -1) | |
| # att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax | |
| att_conv = self.loc_conv(att_prev.unsqueeze(1)) | |
| # att_conv: B x C x 1 x Tmax -> B x Tmax x C | |
| att_conv = att_conv.squeeze(2).transpose(1, 2) | |
| # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim | |
| att_conv = self.mlp_att(att_conv) | |
| # dec_z_tiled: utt x frame x att_dim | |
| dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) | |
| # dot with gvec | |
| # utt x frame x att_dim -> utt x frame | |
| e = self.gvec( | |
| torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) | |
| ).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| w = F.softmax(scaling * e, dim=1) | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) | |
| # update att_prev: B x att_win x Tmax -> B x att_win+1 x Tmax | |
| # -> B x att_win x Tmax | |
| att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1) | |
| att_prev = att_prev[:, 1:] | |
| return c, att_prev | |
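# Illustrative sketch (toy sizes, not part of the original module): AttLoc2D
# carries a (B, att_win, T_max) stack of the last att_win attention weights,
# which the module shifts and returns at every step.
def _example_att_loc2d():
    att = AttLoc2D(
        eprojs=4, dunits=3, att_dim=5, att_win=5, aconv_chans=2, aconv_filts=3
    )
    enc_hs_pad = torch.randn(2, 6, 4)
    enc_hs_len = [6, 4]
    dec_z = torch.zeros(2, 3)
    att.reset()
    att_prev = None
    for _ in range(2):
        c, att_prev = att(enc_hs_pad, enc_hs_len, dec_z, att_prev)
    return c, att_prev  # att_prev: (B, att_win, T_max)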
| class AttLocRec(torch.nn.Module): | |
| """location-aware recurrent attention | |
| This attention is an extended version of location aware attention. | |
| By using an RNN, | |
| it takes the history of attention weights into account. | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int att_dim: attention dimension | |
| :param int aconv_chans: # channels of attention convolution | |
| :param int aconv_filts: filter size of attention convolution | |
| :param bool han_mode: | |
| flag to switch on hierarchical attention mode and not store pre_compute_enc_h | |
| """ | |
| def __init__( | |
| self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False | |
| ): | |
| super(AttLocRec, self).__init__() | |
| self.mlp_enc = torch.nn.Linear(eprojs, att_dim) | |
| self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) | |
| self.loc_conv = torch.nn.Conv2d( | |
| 1, | |
| aconv_chans, | |
| (1, 2 * aconv_filts + 1), | |
| padding=(0, aconv_filts), | |
| bias=False, | |
| ) | |
| self.att_lstm = torch.nn.LSTMCell(aconv_chans, att_dim, bias=False) | |
| self.gvec = torch.nn.Linear(att_dim, 1) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.att_dim = att_dim | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0): | |
| """AttLocRec forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param tuple att_prev_states: previous attention weight and lstm states | |
| ((B, T_max), ((B, att_dim), (B, att_dim))) | |
| :param float scaling: scaling parameter before applying softmax | |
| :return: attention weighted encoder state (B, D_enc) | |
| :rtype: torch.Tensor | |
| :return: previous attention weights and lstm states (w, (hx, cx)) | |
| ((B, T_max), ((B, att_dim), (B, att_dim))) | |
| :rtype: tuple | |
| """ | |
| batch = len(enc_hs_pad) | |
| # pre-compute all h outside the decoder loop | |
| if self.pre_compute_enc_h is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_enc_h = self.mlp_enc(self.enc_h) | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| if att_prev_states is None: | |
| # initialize attention weight with uniform dist. | |
| # if no bias, 0-padded frames stay 0 | |
| att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())) | |
| att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) | |
| # initialize lstm states | |
| att_h = enc_hs_pad.new_zeros(batch, self.att_dim) | |
| att_c = enc_hs_pad.new_zeros(batch, self.att_dim) | |
| att_states = (att_h, att_c) | |
| else: | |
| att_prev = att_prev_states[0] | |
| att_states = att_prev_states[1] | |
| # B x 1 x 1 x T -> B x C x 1 x T | |
| att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) | |
| # apply non-linear | |
| att_conv = F.relu(att_conv) | |
| # B x C x 1 x T -> B x C x 1 x 1 -> B x C | |
| att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(batch, -1) | |
| att_h, att_c = self.att_lstm(att_conv, att_states) | |
| # dec_z_tiled: utt x frame x att_dim | |
| dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) | |
| # dot with gvec | |
| # utt x frame x att_dim -> utt x frame | |
| e = self.gvec( | |
| torch.tanh(att_h.unsqueeze(1) + self.pre_compute_enc_h + dec_z_tiled) | |
| ).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| w = F.softmax(scaling * e, dim=1) | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) | |
| return c, (w, (att_h, att_c)) | |
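# Illustrative sketch (toy sizes, not part of the original module): the state
# threaded through AttLocRec is a tuple of the last attention weights and the
# LSTM cell state, (w, (att_h, att_c)).
def _example_att_loc_rec():
    att = AttLocRec(eprojs=4, dunits=3, att_dim=5, aconv_chans=2, aconv_filts=3)
    enc_hs_pad = torch.randn(2, 6, 4)
    enc_hs_len = [6, 4]
    dec_z = torch.zeros(2, 3)
    att.reset()
    states = None
    for _ in range(2):
        c, states = att(enc_hs_pad, enc_hs_len, dec_z, states)
    w, (att_h, att_c) = states
    return c, w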
| class AttCovLoc(torch.nn.Module): | |
| """Coverage mechanism location aware attention | |
| This attention is a combination of coverage and location-aware attentions. | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int att_dim: attention dimension | |
| :param int aconv_chans: # channels of attention convolution | |
| :param int aconv_filts: filter size of attention convolution | |
| :param bool han_mode: | |
| flag to switch on hierarchical attention mode and not store pre_compute_enc_h | |
| """ | |
| def __init__( | |
| self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False | |
| ): | |
| super(AttCovLoc, self).__init__() | |
| self.mlp_enc = torch.nn.Linear(eprojs, att_dim) | |
| self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) | |
| self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) | |
| self.loc_conv = torch.nn.Conv2d( | |
| 1, | |
| aconv_chans, | |
| (1, 2 * aconv_filts + 1), | |
| padding=(0, aconv_filts), | |
| bias=False, | |
| ) | |
| self.gvec = torch.nn.Linear(att_dim, 1) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.att_dim = att_dim | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.aconv_chans = aconv_chans | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0): | |
| """AttCovLoc forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param list att_prev_list: list of previous attention weight | |
| :param float scaling: scaling parameter before applying softmax | |
| :return: attention weighted encoder state (B, D_enc) | |
| :rtype: torch.Tensor | |
| :return: list of previous attention weights | |
| :rtype: list | |
| """ | |
| batch = len(enc_hs_pad) | |
| # pre-compute all h outside the decoder loop | |
| if self.pre_compute_enc_h is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_enc_h = self.mlp_enc(self.enc_h) | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| # initialize attention weight with uniform dist. | |
| if att_prev_list is None: | |
| # if no bias, 0-padded frames stay 0 | |
| mask = 1.0 - make_pad_mask(enc_hs_len).float() | |
| att_prev_list = [ | |
| to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) | |
| ] | |
| # att_prev_list: L' * [B x T] => cov_vec B x T | |
| cov_vec = sum(att_prev_list) | |
| # cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T | |
| att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length)) | |
| # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans | |
| att_conv = att_conv.squeeze(2).transpose(1, 2) | |
| # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim | |
| att_conv = self.mlp_att(att_conv) | |
| # dec_z_tiled: utt x frame x att_dim | |
| dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) | |
| # dot with gvec | |
| # utt x frame x att_dim -> utt x frame | |
| e = self.gvec( | |
| torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) | |
| ).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| w = F.softmax(scaling * e, dim=1) | |
| att_prev_list += [w] | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) | |
| return c, att_prev_list | |
| class AttMultiHeadDot(torch.nn.Module): | |
| """Multi head dot product attention | |
| Reference: Attention is all you need | |
| (https://arxiv.org/abs/1706.03762) | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int aheads: # heads of multi head attention | |
| :param int att_dim_k: dimension k in multi head attention | |
| :param int att_dim_v: dimension v in multi head attention | |
| :param bool han_mode: flag to switch on hierarchical attention mode | |
| and not store pre_compute_k and pre_compute_v | |
| """ | |
| def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False): | |
| super(AttMultiHeadDot, self).__init__() | |
| self.mlp_q = torch.nn.ModuleList() | |
| self.mlp_k = torch.nn.ModuleList() | |
| self.mlp_v = torch.nn.ModuleList() | |
| for _ in six.moves.range(aheads): | |
| self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] | |
| self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] | |
| self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] | |
| self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.aheads = aheads | |
| self.att_dim_k = att_dim_k | |
| self.att_dim_v = att_dim_v | |
| self.scaling = 1.0 / math.sqrt(att_dim_k) | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_k = None | |
| self.pre_compute_v = None | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_k = None | |
| self.pre_compute_v = None | |
| self.mask = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): | |
| """AttMultiHeadDot forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param torch.Tensor att_prev: dummy (not used) | |
| :return: attention weighted encoder state (B x D_enc) | |
| :rtype: torch.Tensor | |
| :return: list of previous attention weight (B x T_max) * aheads | |
| :rtype: list | |
| """ | |
| batch = enc_hs_pad.size(0) | |
| # pre-compute all k and v outside the decoder loop | |
| if self.pre_compute_k is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_k = [ | |
| torch.tanh(self.mlp_k[h](self.enc_h)) | |
| for h in six.moves.range(self.aheads) | |
| ] | |
| if self.pre_compute_v is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_v = [ | |
| self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) | |
| ] | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| c = [] | |
| w = [] | |
| for h in six.moves.range(self.aheads): | |
| e = torch.sum( | |
| self.pre_compute_k[h] | |
| * torch.tanh(self.mlp_q[h](dec_z)).view(batch, 1, self.att_dim_k), | |
| dim=2, | |
| ) # utt x frame | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| w += [F.softmax(self.scaling * e, dim=1)] | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c += [ | |
| torch.sum( | |
| self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 | |
| ) | |
| ] | |
| # concat all of c | |
| c = self.mlp_o(torch.cat(c, dim=1)) | |
| return c, w | |
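# Illustrative sketch (toy sizes, not part of the original module): the
# multi-head variants return the context projected back to eprojs by mlp_o and
# a list with one (B, T_max) weight tensor per head.
def _example_att_multi_head_dot():
    att = AttMultiHeadDot(eprojs=4, dunits=3, aheads=2, att_dim_k=5, att_dim_v=6)
    enc_hs_pad = torch.randn(2, 6, 4)
    enc_hs_len = [6, 4]
    dec_z = torch.zeros(2, 3)
    att.reset()
    c, w_list = att(enc_hs_pad, enc_hs_len, dec_z, None)
    return c, w_list  # c: (B, eprojs), len(w_list) == aheads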
| class AttMultiHeadAdd(torch.nn.Module): | |
| """Multi head additive attention | |
| Reference: Attention is all you need | |
| (https://arxiv.org/abs/1706.03762) | |
| This attention is multi head attention using additive attention for each head. | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int aheads: # heads of multi head attention | |
| :param int att_dim_k: dimension k in multi head attention | |
| :param int att_dim_v: dimension v in multi head attention | |
| :param bool han_mode: flag to switch on hierarchical attention mode | |
| and not store pre_compute_k and pre_compute_v | |
| """ | |
| def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False): | |
| super(AttMultiHeadAdd, self).__init__() | |
| self.mlp_q = torch.nn.ModuleList() | |
| self.mlp_k = torch.nn.ModuleList() | |
| self.mlp_v = torch.nn.ModuleList() | |
| self.gvec = torch.nn.ModuleList() | |
| for _ in six.moves.range(aheads): | |
| self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] | |
| self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] | |
| self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] | |
| self.gvec += [torch.nn.Linear(att_dim_k, 1)] | |
| self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.aheads = aheads | |
| self.att_dim_k = att_dim_k | |
| self.att_dim_v = att_dim_v | |
| self.scaling = 1.0 / math.sqrt(att_dim_k) | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_k = None | |
| self.pre_compute_v = None | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_k = None | |
| self.pre_compute_v = None | |
| self.mask = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): | |
| """AttMultiHeadAdd forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param torch.Tensor att_prev: dummy (not used) | |
| :return: attention weighted encoder state (B, D_enc) | |
| :rtype: torch.Tensor | |
| :return: list of previous attention weight (B x T_max) * aheads | |
| :rtype: list | |
| """ | |
| batch = enc_hs_pad.size(0) | |
| # pre-compute all k and v outside the decoder loop | |
| if self.pre_compute_k is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_k = [ | |
| self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) | |
| ] | |
| if self.pre_compute_v is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_v = [ | |
| self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) | |
| ] | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| c = [] | |
| w = [] | |
| for h in six.moves.range(self.aheads): | |
| e = self.gvec[h]( | |
| torch.tanh( | |
| self.pre_compute_k[h] | |
| + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) | |
| ) | |
| ).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| w += [F.softmax(self.scaling * e, dim=1)] | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c += [ | |
| torch.sum( | |
| self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 | |
| ) | |
| ] | |
| # concat all of c | |
| c = self.mlp_o(torch.cat(c, dim=1)) | |
| return c, w | |
| class AttMultiHeadLoc(torch.nn.Module): | |
| """Multi head location based attention | |
| Reference: Attention is all you need | |
| (https://arxiv.org/abs/1706.03762) | |
| This attention is multi head attention using location-aware attention for each head. | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int aheads: # heads of multi head attention | |
| :param int att_dim_k: dimension k in multi head attention | |
| :param int att_dim_v: dimension v in multi head attention | |
| :param int aconv_chans: # channels of attention convolution | |
| :param int aconv_filts: filter size of attention convolution | |
| :param bool han_mode: flag to switch on hierarchical attention mode | |
| and not store pre_compute_k and pre_compute_v | |
| """ | |
| def __init__( | |
| self, | |
| eprojs, | |
| dunits, | |
| aheads, | |
| att_dim_k, | |
| att_dim_v, | |
| aconv_chans, | |
| aconv_filts, | |
| han_mode=False, | |
| ): | |
| super(AttMultiHeadLoc, self).__init__() | |
| self.mlp_q = torch.nn.ModuleList() | |
| self.mlp_k = torch.nn.ModuleList() | |
| self.mlp_v = torch.nn.ModuleList() | |
| self.gvec = torch.nn.ModuleList() | |
| self.loc_conv = torch.nn.ModuleList() | |
| self.mlp_att = torch.nn.ModuleList() | |
| for _ in six.moves.range(aheads): | |
| self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] | |
| self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] | |
| self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] | |
| self.gvec += [torch.nn.Linear(att_dim_k, 1)] | |
| self.loc_conv += [ | |
| torch.nn.Conv2d( | |
| 1, | |
| aconv_chans, | |
| (1, 2 * aconv_filts + 1), | |
| padding=(0, aconv_filts), | |
| bias=False, | |
| ) | |
| ] | |
| self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)] | |
| self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.aheads = aheads | |
| self.att_dim_k = att_dim_k | |
| self.att_dim_v = att_dim_v | |
| self.scaling = 1.0 / math.sqrt(att_dim_k) | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_k = None | |
| self.pre_compute_v = None | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_k = None | |
| self.pre_compute_v = None | |
| self.mask = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): | |
| """AttMultiHeadLoc forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param torch.Tensor att_prev: | |
| list of previous attention weight (B x T_max) * aheads | |
| :param float scaling: scaling parameter before applying softmax | |
| :return: attention weighted encoder state (B x D_enc) | |
| :rtype: torch.Tensor | |
| :return: list of previous attention weight (B x T_max) * aheads | |
| :rtype: list | |
| """ | |
| batch = enc_hs_pad.size(0) | |
| # pre-compute all k and v outside the decoder loop | |
| if self.pre_compute_k is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_k = [ | |
| self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) | |
| ] | |
| if self.pre_compute_v is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_v = [ | |
| self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) | |
| ] | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| if att_prev is None: | |
| att_prev = [] | |
| for _ in six.moves.range(self.aheads): | |
| # if no bias, 0-padded frames stay 0 | |
| mask = 1.0 - make_pad_mask(enc_hs_len).float() | |
| att_prev += [ | |
| to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) | |
| ] | |
| c = [] | |
| w = [] | |
| for h in six.moves.range(self.aheads): | |
| att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length)) | |
| att_conv = att_conv.squeeze(2).transpose(1, 2) | |
| att_conv = self.mlp_att[h](att_conv) | |
| e = self.gvec[h]( | |
| torch.tanh( | |
| self.pre_compute_k[h] | |
| + att_conv | |
| + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) | |
| ) | |
| ).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| w += [F.softmax(scaling * e, dim=1)] | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c += [ | |
| torch.sum( | |
| self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 | |
| ) | |
| ] | |
| # concat all of c | |
| c = self.mlp_o(torch.cat(c, dim=1)) | |
| return c, w | |
| class AttMultiHeadMultiResLoc(torch.nn.Module): | |
| """Multi head multi resolution location based attention | |
| Reference: Attention is all you need | |
| (https://arxiv.org/abs/1706.03762) | |
| This attention is multi head attention using location-aware attention for each head. | |
| Furthermore, it uses a different filter size for each head. | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int aheads: # heads of multi head attention | |
| :param int att_dim_k: dimension k in multi head attention | |
| :param int att_dim_v: dimension v in multi head attention | |
| :param int aconv_chans: # channels of attention convolution | |
| :param int aconv_filts: maximum filter size of attention convolution; | |
| each head uses filter size = aconv_filts * (head + 1) // aheads, | |
| e.g. aheads=4, aconv_filts=100 => filter sizes 25, 50, 75, 100 | |
| :param bool han_mode: flag to switch on hierarchical attention mode | |
| and not store pre_compute_k and pre_compute_v | |
| """ | |
| def __init__( | |
| self, | |
| eprojs, | |
| dunits, | |
| aheads, | |
| att_dim_k, | |
| att_dim_v, | |
| aconv_chans, | |
| aconv_filts, | |
| han_mode=False, | |
| ): | |
| super(AttMultiHeadMultiResLoc, self).__init__() | |
| self.mlp_q = torch.nn.ModuleList() | |
| self.mlp_k = torch.nn.ModuleList() | |
| self.mlp_v = torch.nn.ModuleList() | |
| self.gvec = torch.nn.ModuleList() | |
| self.loc_conv = torch.nn.ModuleList() | |
| self.mlp_att = torch.nn.ModuleList() | |
| for h in six.moves.range(aheads): | |
| self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] | |
| self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] | |
| self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] | |
| self.gvec += [torch.nn.Linear(att_dim_k, 1)] | |
| afilts = aconv_filts * (h + 1) // aheads | |
| self.loc_conv += [ | |
| torch.nn.Conv2d( | |
| 1, aconv_chans, (1, 2 * afilts + 1), padding=(0, afilts), bias=False | |
| ) | |
| ] | |
| self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)] | |
| self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.aheads = aheads | |
| self.att_dim_k = att_dim_k | |
| self.att_dim_v = att_dim_v | |
| self.scaling = 1.0 / math.sqrt(att_dim_k) | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_k = None | |
| self.pre_compute_v = None | |
| self.mask = None | |
| self.han_mode = han_mode | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_k = None | |
| self.pre_compute_v = None | |
| self.mask = None | |
| def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): | |
| """AttMultiHeadMultiResLoc forward | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param torch.Tensor att_prev: list of previous attention weight | |
| (B x T_max) * aheads | |
| :return: attention weighted encoder state (B x D_enc) | |
| :rtype: torch.Tensor | |
| :return: list of previous attention weight (B x T_max) * aheads | |
| :rtype: list | |
| """ | |
| batch = enc_hs_pad.size(0) | |
| # pre-compute all k and v outside the decoder loop | |
| if self.pre_compute_k is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_k = [ | |
| self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) | |
| ] | |
| if self.pre_compute_v is None or self.han_mode: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_v = [ | |
| self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) | |
| ] | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| if att_prev is None: | |
| att_prev = [] | |
| for _ in six.moves.range(self.aheads): | |
| # if no bias, 0-padded frames stay 0 | |
| mask = 1.0 - make_pad_mask(enc_hs_len).float() | |
| att_prev += [ | |
| to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) | |
| ] | |
| c = [] | |
| w = [] | |
| for h in six.moves.range(self.aheads): | |
| att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length)) | |
| att_conv = att_conv.squeeze(2).transpose(1, 2) | |
| att_conv = self.mlp_att[h](att_conv) | |
| e = self.gvec[h]( | |
| torch.tanh( | |
| self.pre_compute_k[h] | |
| + att_conv | |
| + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) | |
| ) | |
| ).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| w += [F.softmax(self.scaling * e, dim=1)] | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c += [ | |
| torch.sum( | |
| self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 | |
| ) | |
| ] | |
| # concat all of c | |
| c = self.mlp_o(torch.cat(c, dim=1)) | |
| return c, w | |
| class AttForward(torch.nn.Module): | |
| """Forward attention module. | |
| Reference: | |
| Forward attention in sequence-to-sequence acoustic modeling for speech synthesis | |
| (https://arxiv.org/pdf/1807.06736.pdf) | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int att_dim: attention dimension | |
| :param int aconv_chans: # channels of attention convolution | |
| :param int aconv_filts: filter size of attention convolution | |
| """ | |
| def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts): | |
| super(AttForward, self).__init__() | |
| self.mlp_enc = torch.nn.Linear(eprojs, att_dim) | |
| self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) | |
| self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) | |
| self.loc_conv = torch.nn.Conv2d( | |
| 1, | |
| aconv_chans, | |
| (1, 2 * aconv_filts + 1), | |
| padding=(0, aconv_filts), | |
| bias=False, | |
| ) | |
| self.gvec = torch.nn.Linear(att_dim, 1) | |
| self.dunits = dunits | |
| self.eprojs = eprojs | |
| self.att_dim = att_dim | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| def reset(self): | |
| """reset states""" | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| def forward( | |
| self, | |
| enc_hs_pad, | |
| enc_hs_len, | |
| dec_z, | |
| att_prev, | |
| scaling=1.0, | |
| last_attended_idx=None, | |
| backward_window=1, | |
| forward_window=3, | |
| ): | |
| """Calculate AttForward forward propagation. | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B x D_dec) | |
| :param torch.Tensor att_prev: attention weights of previous step | |
| :param float scaling: scaling parameter before applying softmax | |
| :param int last_attended_idx: index of the last attended input | |
| :param int backward_window: backward window size in attention constraint | |
| :param int forward_window: forward window size in attention constraint | |
| :return: attention weighted encoder state (B, D_enc) | |
| :rtype: torch.Tensor | |
| :return: previous attention weights (B x T_max) | |
| :rtype: torch.Tensor | |
| """ | |
| batch = len(enc_hs_pad) | |
| # pre-compute all h outside the decoder loop | |
| if self.pre_compute_enc_h is None: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_enc_h = self.mlp_enc(self.enc_h) | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| if att_prev is None: | |
| # initial attention will be [1, 0, 0, ...] | |
| att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2]) | |
| att_prev[:, 0] = 1.0 | |
| # att_prev: utt x frame -> utt x 1 x 1 x frame | |
| # -> utt x att_conv_chans x 1 x frame | |
| att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) | |
| # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans | |
| att_conv = att_conv.squeeze(2).transpose(1, 2) | |
| # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim | |
| att_conv = self.mlp_att(att_conv) | |
| # dec_z_tiled: utt x frame x att_dim | |
| dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1) | |
| # dot with gvec | |
| # utt x frame x att_dim -> utt x frame | |
| e = self.gvec( | |
| torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv) | |
| ).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| # apply monotonic attention constraint (mainly for TTS) | |
| if last_attended_idx is not None: | |
| e = _apply_attention_constraint( | |
| e, last_attended_idx, backward_window, forward_window | |
| ) | |
| w = F.softmax(scaling * e, dim=1) | |
| # forward attention | |
| att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1] | |
| w = (att_prev + att_prev_shift) * w | |
| # NOTE: clamp is needed to avoid nan gradient | |
| w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1) | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c = torch.sum(self.enc_h * w.unsqueeze(-1), dim=1) | |
| return c, w | |
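# Illustrative sketch (toy sizes, not part of the original module): the first
# call with att_prev=None starts from the one-hot [1, 0, 0, ...] distribution;
# passing last_attended_idx additionally constrains the weights to a window
# around the previous attention peak (batch size 1 only).
def _example_att_forward():
    att = AttForward(eprojs=4, dunits=3, att_dim=5, aconv_chans=2, aconv_filts=3)
    enc_hs_pad = torch.randn(2, 6, 4)
    enc_hs_len = [6, 4]
    dec_z = torch.zeros(2, 3)
    att.reset()
    c, w = att(enc_hs_pad, enc_hs_len, dec_z, None)
    return c, w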
| class AttForwardTA(torch.nn.Module): | |
| """Forward attention with transition agent module. | |
| Reference: | |
| Forward attention in sequence-to-sequence acoustic modeling for speech synthesis | |
| (https://arxiv.org/pdf/1807.06736.pdf) | |
| :param int eunits: # units of encoder | |
| :param int dunits: # units of decoder | |
| :param int att_dim: attention dimension | |
| :param int aconv_chans: # channels of attention convolution | |
| :param int aconv_filts: filter size of attention convolution | |
| :param int odim: output dimension | |
| """ | |
| def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim): | |
| super(AttForwardTA, self).__init__() | |
| self.mlp_enc = torch.nn.Linear(eunits, att_dim) | |
| self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) | |
| self.mlp_ta = torch.nn.Linear(eunits + dunits + odim, 1) | |
| self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) | |
| self.loc_conv = torch.nn.Conv2d( | |
| 1, | |
| aconv_chans, | |
| (1, 2 * aconv_filts + 1), | |
| padding=(0, aconv_filts), | |
| bias=False, | |
| ) | |
| self.gvec = torch.nn.Linear(att_dim, 1) | |
| self.dunits = dunits | |
| self.eunits = eunits | |
| self.att_dim = att_dim | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| self.trans_agent_prob = 0.5 | |
| def reset(self): | |
| self.h_length = None | |
| self.enc_h = None | |
| self.pre_compute_enc_h = None | |
| self.mask = None | |
| self.trans_agent_prob = 0.5 | |
| def forward( | |
| self, | |
| enc_hs_pad, | |
| enc_hs_len, | |
| dec_z, | |
| att_prev, | |
| out_prev, | |
| scaling=1.0, | |
| last_attended_idx=None, | |
| backward_window=1, | |
| forward_window=3, | |
| ): | |
| """Calculate AttForwardTA forward propagation. | |
| :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits) | |
| :param list enc_hs_len: padded encoder hidden state length (B) | |
| :param torch.Tensor dec_z: decoder hidden state (B, dunits) | |
| :param torch.Tensor att_prev: attention weights of previous step | |
| :param torch.Tensor out_prev: decoder outputs of previous step (B, odim) | |
| :param float scaling: scaling parameter before applying softmax | |
| :param int last_attended_idx: index of the last attended input | |
| :param int backward_window: backward window size in attention constraint | |
| :param int forward_window: forward window size in attention constraint | |
| :return: attention weighted encoder state (B, dunits) | |
| :rtype: torch.Tensor | |
| :return: previous attention weights (B, Tmax) | |
| :rtype: torch.Tensor | |
| """ | |
| batch = len(enc_hs_pad) | |
| # pre-compute all h outside the decoder loop | |
| if self.pre_compute_enc_h is None: | |
| self.enc_h = enc_hs_pad # utt x frame x hdim | |
| self.h_length = self.enc_h.size(1) | |
| # utt x frame x att_dim | |
| self.pre_compute_enc_h = self.mlp_enc(self.enc_h) | |
| if dec_z is None: | |
| dec_z = enc_hs_pad.new_zeros(batch, self.dunits) | |
| else: | |
| dec_z = dec_z.view(batch, self.dunits) | |
| if att_prev is None: | |
| # initial attention will be [1, 0, 0, ...] | |
| att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2]) | |
| att_prev[:, 0] = 1.0 | |
| # att_prev: utt x frame -> utt x 1 x 1 x frame | |
| # -> utt x att_conv_chans x 1 x frame | |
| att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) | |
| # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans | |
| att_conv = att_conv.squeeze(2).transpose(1, 2) | |
| # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim | |
| att_conv = self.mlp_att(att_conv) | |
| # dec_z_tiled: utt x frame x att_dim | |
| dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) | |
| # dot with gvec | |
| # utt x frame x att_dim -> utt x frame | |
| e = self.gvec( | |
| torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) | |
| ).squeeze(2) | |
| # NOTE: consider zero padding when computing w. | |
| if self.mask is None: | |
| self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) | |
| e.masked_fill_(self.mask, -float("inf")) | |
| # apply monotonic attention constraint (mainly for TTS) | |
| if last_attended_idx is not None: | |
| e = _apply_attention_constraint( | |
| e, last_attended_idx, backward_window, forward_window | |
| ) | |
| w = F.softmax(scaling * e, dim=1) | |
| # forward attention | |
| att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1] | |
| w = ( | |
| self.trans_agent_prob * att_prev | |
| + (1 - self.trans_agent_prob) * att_prev_shift | |
| ) * w | |
| # NOTE: clamp is needed to avoid nan gradient | |
| w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1) | |
| # weighted sum over frames | |
| # utt x hdim | |
| # NOTE use bmm instead of sum(*) | |
| c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) | |
| # update transition agent prob | |
| self.trans_agent_prob = torch.sigmoid( | |
| self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1)) | |
| ) | |
| return c, w | |
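# Illustrative sketch (toy sizes, not part of the original module): compared
# with AttForward, AttForwardTA also needs the previous decoder output
# out_prev (B, odim) so it can update the transition agent probability.
def _example_att_forward_ta():
    att = AttForwardTA(
        eunits=4, dunits=3, att_dim=5, aconv_chans=2, aconv_filts=3, odim=2
    )
    enc_hs_pad = torch.randn(2, 6, 4)
    enc_hs_len = [6, 4]
    dec_z = torch.zeros(2, 3)
    out_prev = torch.zeros(2, 2)  # previous decoder output (B, odim)
    att.reset()
    c, w = att(enc_hs_pad, enc_hs_len, dec_z, None, out_prev)
    return c, w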
| def att_for(args, num_att=1, han_mode=False): | |
| """Instantiates an attention module given the program arguments | |
| :param Namespace args: The arguments | |
| :param int num_att: number of attention modules | |
| (in multi-speaker case, it can be 2 or more) | |
| :param bool han_mode: switch on/off mode of hierarchical attention network (HAN) | |
| :rtype: torch.nn.Module | |
| :return: The attention module | |
| """ | |
| att_list = torch.nn.ModuleList() | |
| num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility | |
| aheads = getattr(args, "aheads", None) | |
| awin = getattr(args, "awin", None) | |
| aconv_chans = getattr(args, "aconv_chans", None) | |
| aconv_filts = getattr(args, "aconv_filts", None) | |
| if num_encs == 1: | |
| for i in range(num_att): | |
| att = initial_att( | |
| args.atype, | |
| args.eprojs, | |
| args.dunits, | |
| aheads, | |
| args.adim, | |
| awin, | |
| aconv_chans, | |
| aconv_filts, | |
| ) | |
| att_list.append(att) | |
| elif num_encs > 1: # multi-encoder mode | |
| if han_mode: | |
| att = initial_att( | |
| args.han_type, | |
| args.eprojs, | |
| args.dunits, | |
| args.han_heads, | |
| args.han_dim, | |
| args.han_win, | |
| args.han_conv_chans, | |
| args.han_conv_filts, | |
| han_mode=True, | |
| ) | |
| return att | |
| else: | |
| att_list = torch.nn.ModuleList() | |
| for idx in range(num_encs): | |
| att = initial_att( | |
| args.atype[idx], | |
| args.eprojs, | |
| args.dunits, | |
| aheads[idx], | |
| args.adim[idx], | |
| awin[idx], | |
| aconv_chans[idx], | |
| aconv_filts[idx], | |
| ) | |
| att_list.append(att) | |
| else: | |
| raise ValueError( | |
| "Number of encoders needs to be more than one. {}".format(num_encs) | |
| ) | |
| return att_list | |
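# Illustrative sketch (not part of the original module): building attention
# from an argparse-style Namespace, as the training scripts do; the attribute
# values below are toy placeholders.
def _example_att_for():
    from argparse import Namespace
    args = Namespace(
        atype="location", eprojs=4, dunits=3, adim=5,
        aheads=None, awin=None, aconv_chans=2, aconv_filts=3,
    )
    att_list = att_for(args)  # ModuleList containing a single AttLoc
    return att_list[0]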
| def initial_att( | |
| atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts, han_mode=False | |
| ): | |
| """Instantiates a single attention module | |
| :param str atype: attention type | |
| :param int eprojs: # projection-units of encoder | |
| :param int dunits: # units of decoder | |
| :param int aheads: # heads of multi head attention | |
| :param int adim: attention dimension | |
| :param int awin: attention window size | |
| :param int aconv_chans: # channels of attention convolution | |
| :param int aconv_filts: filter size of attention convolution | |
| :param bool han_mode: flag to switch on hierarchical attention mode | |
| :return: The attention module | |
| """ | |
| if atype == "noatt": | |
| att = NoAtt() | |
| elif atype == "dot": | |
| att = AttDot(eprojs, dunits, adim, han_mode) | |
| elif atype == "add": | |
| att = AttAdd(eprojs, dunits, adim, han_mode) | |
| elif atype == "location": | |
| att = AttLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode) | |
| elif atype == "location2d": | |
| att = AttLoc2D(eprojs, dunits, adim, awin, aconv_chans, aconv_filts, han_mode) | |
| elif atype == "location_recurrent": | |
| att = AttLocRec(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode) | |
| elif atype == "coverage": | |
| att = AttCov(eprojs, dunits, adim, han_mode) | |
| elif atype == "coverage_location": | |
| att = AttCovLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode) | |
| elif atype == "multi_head_dot": | |
| att = AttMultiHeadDot(eprojs, dunits, aheads, adim, adim, han_mode) | |
| elif atype == "multi_head_add": | |
| att = AttMultiHeadAdd(eprojs, dunits, aheads, adim, adim, han_mode) | |
| elif atype == "multi_head_loc": | |
| att = AttMultiHeadLoc( | |
| eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode | |
| ) | |
| elif atype == "multi_head_multi_res_loc": | |
| att = AttMultiHeadMultiResLoc( | |
| eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode | |
| ) | |
| return att | |
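# Illustrative sketch (not part of the original module): creating a single
# module directly; hyper-parameters a given attention type does not use (here
# aheads and awin for location attention) may be None.
def _example_initial_att():
    return initial_att(
        "location", eprojs=4, dunits=3, aheads=None, adim=5,
        awin=None, aconv_chans=2, aconv_filts=3,
    )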
| def att_to_numpy(att_ws, att): | |
| """Converts attention weights to a numpy array given the attention | |
| :param list att_ws: The attention weights | |
| :param torch.nn.Module att: The attention | |
| :rtype: np.ndarray | |
| :return: The numpy array of the attention weights | |
| """ | |
| # convert to numpy array with the shape (B, Lmax, Tmax) | |
| if isinstance(att, AttLoc2D): | |
| # att_ws => list of previous concatenated attentions | |
| att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy() | |
| elif isinstance(att, (AttCov, AttCovLoc)): | |
| # att_ws => list of list of previous attentions | |
| att_ws = ( | |
| torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy() | |
| ) | |
| elif isinstance(att, AttLocRec): | |
| # att_ws => list of tuple of attention and hidden states | |
| att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy() | |
| elif isinstance( | |
| att, | |
| (AttMultiHeadDot, AttMultiHeadAdd, AttMultiHeadLoc, AttMultiHeadMultiResLoc), | |
| ): | |
| # att_ws => list of list of each head attention | |
| n_heads = len(att_ws[0]) | |
| att_ws_sorted_by_head = [] | |
| for h in six.moves.range(n_heads): | |
| att_ws_head = torch.stack([aw[h] for aw in att_ws], dim=1) | |
| att_ws_sorted_by_head += [att_ws_head] | |
| att_ws = torch.stack(att_ws_sorted_by_head, dim=1).cpu().numpy() | |
| else: | |
| # att_ws => list of attentions | |
| att_ws = torch.stack(att_ws, dim=1).cpu().numpy() | |
| return att_ws | |
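# Illustrative sketch (toy sizes, not part of the original module): collect
# the per-step weights during decoding and convert them to a single
# (B, Lmax, Tmax) numpy array for plotting.
def _example_att_to_numpy():
    att = AttAdd(eprojs=4, dunits=3, att_dim=5)
    enc_hs_pad = torch.randn(2, 6, 4)
    enc_hs_len = [6, 4]
    dec_z = torch.zeros(2, 3)
    att.reset()
    att_ws = []
    for _ in range(3):  # 3 decoder steps -> Lmax == 3
        _, w = att(enc_hs_pad, enc_hs_len, dec_z, None)
        att_ws.append(w)
    return att_to_numpy(att_ws, att)  # numpy array of shape (2, 3, 6)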