# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import ModuleList

from mmocr.models.builder import DECODERS
from mmocr.models.common.modules import PositionalEncoding

from .base_decoder import BaseDecoder


@DECODERS.register_module()
class ABILanguageDecoder(BaseDecoder):
    r"""Transformer-based language model responsible for spell correction.

    Implementation of the language model of `ABINet
    <https://arxiv.org/abs/2103.06495>`_.

    Args:
        d_model (int): Hidden size of input.
        n_head (int): Number of attention heads.
        d_inner (int): Hidden size of the feedforward network.
        n_layers (int): Number of decoder layers.
        max_seq_len (int): Maximum text sequence length :math:`T`.
        dropout (float): Dropout rate.
        detach_tokens (bool): Whether to block the gradient flow at the input
            tokens.
        num_chars (int): Number of text characters :math:`C`.
        use_self_attn (bool): If True, use self attention in decoder layers;
            otherwise only cross attention is used.
        pad_idx (int): Index of the token indicating the end of output, used
            to compute the output length. It is usually the index of the
            `<EOS>` or `<PAD>` token.
        init_cfg (dict): Specifies the initialization method for model layers.
    """

    def __init__(self,
                 d_model=512,
                 n_head=8,
                 d_inner=2048,
                 n_layers=4,
                 max_seq_len=40,
                 dropout=0.1,
                 detach_tokens=True,
                 num_chars=90,
                 use_self_attn=False,
                 pad_idx=0,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.detach_tokens = detach_tokens
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.proj = nn.Linear(num_chars, d_model, bias=False)
        self.token_encoder = PositionalEncoding(
            d_model, n_position=self.max_seq_len, dropout=0.1)
        self.pos_encoder = PositionalEncoding(
            d_model, n_position=self.max_seq_len)
        self.pad_idx = pad_idx

        if use_self_attn:
            operation_order = ('self_attn', 'norm', 'cross_attn', 'norm',
                               'ffn', 'norm')
        else:
            operation_order = ('cross_attn', 'norm', 'ffn', 'norm')

        decoder_layer = BaseTransformerLayer(
            operation_order=operation_order,
            attn_cfgs=dict(
                type='MultiheadAttention',
                embed_dims=d_model,
                num_heads=n_head,
                attn_drop=dropout,
                dropout_layer=dict(type='Dropout', drop_prob=dropout),
            ),
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=d_model,
                feedforward_channels=d_inner,
                ffn_drop=dropout,
            ),
            norm_cfg=dict(type='LN'),
        )
        self.decoder_layers = ModuleList(
            [copy.deepcopy(decoder_layer) for _ in range(n_layers)])

        self.cls = nn.Linear(d_model, num_chars)

    def forward_train(self, feat, logits, targets_dict, img_metas):
        """
        Args:
            logits (Tensor): Raw language logits. Shape :math:`(N, T, C)`.

        Returns:
            A dict with keys ``feature`` and ``logits``.

            - feature (Tensor): Shape :math:`(N, T, E)`. Raw textual features
              for the vision-language aligner.
            - logits (Tensor): Shape :math:`(N, T, C)`. The raw logits for
              characters after spell correction.
        """
        lengths = self._get_length(logits)
        lengths.clamp_(2, self.max_seq_len)
        tokens = torch.softmax(logits, dim=-1)
        if self.detach_tokens:
            tokens = tokens.detach()
        embed = self.proj(tokens)  # (N, T, E)
        embed = self.token_encoder(embed)  # (N, T, E)
        padding_mask = self._get_padding_mask(lengths, self.max_seq_len)

        zeros = embed.new_zeros(*embed.shape)
        query = self.pos_encoder(zeros)
        query = query.permute(1, 0, 2)  # (T, N, E)
        embed = embed.permute(1, 0, 2)
        location_mask = self._get_location_mask(self.max_seq_len,
                                                tokens.device)
        output = query
        for m in self.decoder_layers:
            output = m(
                query=output,
                key=embed,
                value=embed,
                attn_masks=location_mask,
                key_padding_mask=padding_mask)
        output = output.permute(1, 0, 2)  # (N, T, E)

        logits = self.cls(output)  # (N, T, C)
        return {'feature': output, 'logits': logits}

    def forward_test(self, feat, out_enc, img_metas):
        return self.forward_train(feat, out_enc, None, img_metas)

    def _get_length(self, logit, dim=-1):
        """Greedy decoder to obtain length from logit.

        Returns the first location of the padding index, or the full sequence
        length if no padding index is found.
        """
        # Boolean tensor marking positions predicted as the end/pad token
        out = (logit.argmax(dim=-1) == self.pad_idx)
        abn = out.any(dim)
        # Index of the first end token along ``dim``
        out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
        out = out + 1
        out = torch.where(abn, out, out.new_tensor(logit.shape[1]))
        return out
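
    # Worked example (illustrative values, not from the original source): with
    # pad_idx=0 and per-step argmax ids [[5, 3, 0, 7]], the first 0 occurs at
    # index 2, so the decoded length is 2 + 1 = 3 (the end token itself is
    # counted). A row containing no pad_idx maps to the full length
    # logit.shape[1].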

    @staticmethod
    def _get_location_mask(seq_len, device=None):
        """Generate location masks given input sequence length.

        Args:
            seq_len (int): The length of the input sequence to the
                transformer.
            device (torch.device or str, optional): The device on which the
                masks will be placed.

        Returns:
            Tensor: A mask tensor of shape (seq_len, seq_len) with -infs on
            the diagonal and zeros elsewhere.
        """
        mask = torch.eye(seq_len, device=device)
        mask = mask.float().masked_fill(mask == 1, float('-inf'))
        return mask
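
    # Design note: the -inf diagonal keeps each query position from attending
    # to the token at its own location, so every character must be inferred
    # from the surrounding context. This is what lets the module act as a
    # spell corrector rather than an identity mapping over the (possibly
    # noisy) vision predictions.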

    @staticmethod
    def _get_padding_mask(length, max_length):
        """Generate padding masks.

        Args:
            length (Tensor): Shape :math:`(N,)`.
            max_length (int): The maximum sequence length :math:`T`.

        Returns:
            Tensor: A bool tensor of shape :math:`(N, T)` that is True at
            positions beyond the decoded length and False elsewhere.
        """
        length = length.unsqueeze(-1)
        grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
        return grid >= length
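
# A minimal usage sketch (not part of the upstream module; requires mmcv and
# mmocr to be installed, and all values below are illustrative): given raw
# vision-branch logits of shape (N, T, C) with C == num_chars and
# T == max_seq_len, the decoder returns spell-corrected logits of the same
# shape plus (N, T, E) textual features for the fusion/alignment step.
#
#     decoder = ABILanguageDecoder(d_model=512, num_chars=90, max_seq_len=40)
#     vision_logits = torch.randn(2, 40, 90)           # (N, T, C)
#     out = decoder.forward_train(None, vision_logits, None, None)
#     out['logits'].shape    # torch.Size([2, 40, 90])
#     out['feature'].shape   # torch.Size([2, 40, 512])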