# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import ModuleList

from mmocr.models.builder import DECODERS
from mmocr.models.common import PositionalEncoding, TFDecoderLayer

from .base_decoder import BaseDecoder


@DECODERS.register_module()
class NRTRDecoder(BaseDecoder):
    """Transformer Decoder block with self attention mechanism.

    Args:
        n_layers (int): Number of attention layers.
        d_embedding (int): Language embedding dimension.
        n_head (int): Number of parallel attention heads.
        d_k (int): Dimension of the key vector.
        d_v (int): Dimension of the value vector.
        d_model (int): Dimension :math:`D_m` of the input from previous model.
        d_inner (int): Hidden dimension of feedforward layers.
        n_position (int): Length of the positional encoding vector. Must be
            greater than ``max_seq_len``.
        dropout (float): Dropout rate.
        num_classes (int): Number of output classes :math:`C`.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of ``<SOS>``.
        padding_idx (int): The index of ``<PAD>``.
        init_cfg (dict or list[dict], optional): Initialization configs.

    Warning:
        This decoder will not predict the final class which is assumed to be
        ``<PAD>``. Therefore, its output size is always :math:`C - 1`.
        ``<PAD>`` is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len

        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)

        self.layer_stack = ModuleList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        pred_num_class = num_classes - 1  # ignore padding_idx
        self.classifier = nn.Linear(d_model, pred_num_class)

    @staticmethod
    def get_pad_mask(seq, pad_idx):
        # (N, T) -> (N, 1, T); True at non-padding positions.
        return (seq != pad_idx).unsqueeze(-2)

    @staticmethod
    def get_subsequent_mask(seq):
        """For masking out the subsequent info."""
        len_s = seq.size(1)
        subsequent_mask = 1 - torch.triu(
            torch.ones((len_s, len_s), device=seq.device), diagonal=1)
        subsequent_mask = subsequent_mask.unsqueeze(0).bool()
        return subsequent_mask
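
    # Illustration (added note, not in the original source): for len_s = 3,
    # ``get_subsequent_mask`` returns the lower-triangular causal mask
    #     [[[ True, False, False],
    #       [ True,  True, False],
    #       [ True,  True,  True]]]
    # so position i can only attend to positions <= i. In ``_attention``
    # below it is ANDed with the pad mask of shape (N, 1, T), so padded
    # target positions are masked out as well.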

    def _attention(self, trg_seq, src, src_mask=None):
        """Run the stacked decoder layers over the embedded target sequence,
        cross-attending to the encoder output ``src``."""
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        tgt = self.dropout(trg_pos_encoded)

        trg_mask = self.get_pad_mask(
            trg_seq,
            pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)
        output = tgt
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)

        return output

    def _get_mask(self, logit, img_metas):
        """Build an (N, T) mask marking the valid (non-padded) width of each
        encoder output sequence, from ``valid_ratio`` in ``img_metas``."""
        valid_ratios = None
        if img_metas is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ]
        N, T, _ = logit.size()
        mask = None
        if valid_ratios is not None:
            mask = logit.new_zeros((N, T))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(T, math.ceil(T * valid_ratio))
                mask[i, :valid_width] = 1

        return mask
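
    # Worked example (illustrative): with T = 5 and valid_ratio = 0.6,
    # valid_width = min(5, ceil(5 * 0.6)) = 3, so that image's mask row is
    # [1, 1, 1, 0, 0] and cross-attention only covers the first three
    # (valid) encoder positions.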

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        r"""
        Args:
            feat (None): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                where :math:`D_m` is ``d_model``.
            targets_dict (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of
                a character.
            img_metas (dict): A dict that contains meta information of input
                images. Preferably with the key ``valid_ratio``.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, T, C)`.
        """
        src_mask = self._get_mask(out_enc, img_metas)
        targets = targets_dict['padded_targets'].to(out_enc.device)
        attn_output = self._attention(targets, out_enc, src_mask=src_mask)
        outputs = self.classifier(attn_output)

        return outputs

    def forward_test(self, feat, out_enc, img_metas):
        """Greedy autoregressive decoding: start from ``<SOS>`` and feed the
        argmax of each step back in as the input for the next step.

        Returns per-step class probabilities of shape :math:`(N, T, C)`.
        """
        src_mask = self._get_mask(out_enc, img_metas)
        N = out_enc.size(0)
        init_target_seq = torch.full((N, self.max_seq_len + 1),
                                     self.padding_idx,
                                     device=out_enc.device,
                                     dtype=torch.long)
        # bsz * seq_len
        init_target_seq[:, 0] = self.start_idx

        outputs = []
        for step in range(0, self.max_seq_len):
            decoder_output = self._attention(
                init_target_seq, out_enc, src_mask=src_mask)
            # bsz * seq_len * C
            step_result = F.softmax(
                self.classifier(decoder_output[:, step, :]), dim=-1)
            # bsz * num_classes
            outputs.append(step_result)
            _, step_max_index = torch.max(step_result, dim=-1)
            init_target_seq[:, step + 1] = step_max_index
        outputs = torch.stack(outputs, dim=1)

        return outputs
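

# Minimal usage sketch (added for illustration; not part of the original
# file). It assumes an MMOCR 0.x environment in which the imports above
# resolve; batch size, feature length and class count are arbitrary.
if __name__ == '__main__':
    decoder = NRTRDecoder(num_classes=93, max_seq_len=40, padding_idx=92)
    decoder.eval()

    # Fake encoder output: batch of 2, 25 time steps, d_model = 512.
    out_enc = torch.rand(2, 25, 512)
    img_metas = [dict(valid_ratio=1.0), dict(valid_ratio=0.6)]

    with torch.no_grad():
        probs = decoder.forward_test(feat=None, out_enc=out_enc,
                                     img_metas=img_metas)
    print(probs.shape)  # torch.Size([2, 40, 92]); C - 1 = 92 classes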