# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmocr.models.builder import DECODERS
from mmocr.models.common.modules import PositionalEncoding

from .base_decoder import BaseDecoder


@DECODERS.register_module()
class ABIVisionDecoder(BaseDecoder):
| """Converts visual features into text characters. | |
| Implementation of VisionEncoder in | |
| `ABINet <https://arxiv.org/abs/1910.04396>`_. | |
| Args: | |
| in_channels (int): Number of channels :math:`E` of input vector. | |
| num_channels (int): Number of channels of hidden vectors in mini U-Net. | |
| h (int): Height :math:`H` of input image features. | |
| w (int): Width :math:`W` of input image features. | |
| in_channels (int): Number of channels of input image features. | |
| num_channels (int): Number of channels of hidden vectors in mini U-Net. | |
| attn_height (int): Height :math:`H` of input image features. | |
| attn_width (int): Width :math:`W` of input image features. | |
| attn_mode (str): Upsampling mode for :obj:`torch.nn.Upsample` in mini | |
| U-Net. | |
| max_seq_len (int): Maximum text sequence length :math:`T`. | |
| num_chars (int): Number of text characters :math:`C`. | |
| init_cfg (dict): Specifies the initialization method for model layers. | |
| """ | |

    def __init__(self,
                 in_channels=512,
                 num_channels=64,
                 attn_height=8,
                 attn_width=32,
                 attn_mode='nearest',
                 max_seq_len=40,
                 num_chars=90,
                 init_cfg=dict(type='Xavier', layer='Conv2d'),
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.max_seq_len = max_seq_len

        # For mini U-Net
        self.k_encoder = nn.Sequential(
            self._encoder_layer(in_channels, num_channels, stride=(1, 2)),
            self._encoder_layer(num_channels, num_channels, stride=(2, 2)),
            self._encoder_layer(num_channels, num_channels, stride=(2, 2)),
            self._encoder_layer(num_channels, num_channels, stride=(2, 2)))
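        # The encoder above progressively halves the spatial size (e.g.
        # 8x32 -> 8x16 -> 4x8 -> 2x4 -> 1x2 when the input matches the default
        # attention size); the decoder below upsamples back and fuses the
        # encoder outputs via skip connections, forming a mini U-Net over the
        # key features.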
        self.k_decoder = nn.Sequential(
            self._decoder_layer(
                num_channels, num_channels, scale_factor=2, mode=attn_mode),
            self._decoder_layer(
                num_channels, num_channels, scale_factor=2, mode=attn_mode),
            self._decoder_layer(
                num_channels, num_channels, scale_factor=2, mode=attn_mode),
            self._decoder_layer(
                num_channels,
                in_channels,
                size=(attn_height, attn_width),
                mode=attn_mode))

        self.pos_encoder = PositionalEncoding(in_channels, max_seq_len)
        self.project = nn.Linear(in_channels, in_channels)
        self.cls = nn.Linear(in_channels, num_chars)

    def forward_train(self,
                      feat,
                      out_enc=None,
                      targets_dict=None,
                      img_metas=None):
        """
        Args:
            feat (Tensor): Image features of shape (N, E, H, W).

        Returns:
            dict: A dict with keys ``feature``, ``logits`` and ``attn_scores``.

                - | feature (Tensor): Shape (N, T, E). Raw visual features for
                    language decoder.
                - | logits (Tensor): Shape (N, T, C). The raw logits for
                    characters.
                - | attn_scores (Tensor): Shape (N, T, H, W). Intermediate
                    result for vision-language aligner.
        """
        # Position attention
        N, E, H, W = feat.size()
        k, v = feat, feat  # (N, E, H, W)

        # Apply mini U-Net on k
        features = []
        for i in range(len(self.k_encoder)):
            k = self.k_encoder[i](k)
            features.append(k)
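        # Decode with skip connections: each upsampled map is added to the
        # encoder output of matching resolution (U-Net style fusion).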
        for i in range(len(self.k_decoder) - 1):
            k = self.k_decoder[i](k)
            k = k + features[len(self.k_decoder) - 2 - i]
        k = self.k_decoder[-1](k)

        # q = positional encoding
        zeros = feat.new_zeros((N, self.max_seq_len, E))  # (N, T, E)
        q = self.pos_encoder(zeros)  # (N, T, E)
        q = self.project(q)  # (N, T, E)

        # Attention encoding
        attn_scores = torch.bmm(q, k.flatten(2, 3))  # (N, T, (H*W))
        attn_scores = attn_scores / (E**0.5)
        attn_scores = torch.softmax(attn_scores, dim=-1)
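        # Each of the T character queries now holds a normalized attention map
        # over the H*W spatial positions of the key features.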
        v = v.permute(0, 2, 3, 1).view(N, -1, E)  # (N, (H*W), E)
        attn_vecs = torch.bmm(attn_scores, v)  # (N, T, E)

        logits = self.cls(attn_vecs)
        result = {
            'feature': attn_vecs,
            'logits': logits,
            'attn_scores': attn_scores.view(N, -1, H, W)
        }
        return result

    def forward_test(self, feat, out_enc=None, img_metas=None):
        return self.forward_train(feat, out_enc=out_enc, img_metas=img_metas)

    def _encoder_layer(self,
                       in_channels,
                       out_channels,
                       kernel_size=3,
                       stride=2,
                       padding=1):
        return ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

    def _decoder_layer(self,
                       in_channels,
                       out_channels,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       mode='nearest',
                       scale_factor=None,
                       size=None):
        align_corners = None if mode == 'nearest' else True
        return nn.Sequential(
            nn.Upsample(
                size=size,
                scale_factor=scale_factor,
                mode=mode,
                align_corners=align_corners),
            ConvModule(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                norm_cfg=dict(type='BN'),
                act_cfg=dict(type='ReLU')))
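

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes mmocr 0.x and mmcv are installed, uses random features whose
# spatial size matches the default ``attn_height``/``attn_width`` above, and
# should be run with ``python -m`` from inside the package so the relative
# import of ``BaseDecoder`` resolves.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    decoder = ABIVisionDecoder()
    feat = torch.rand(2, 512, 8, 32)  # (N, E, H, W)
    out = decoder.forward_train(feat)
    # With the defaults: feature (2, 40, 512), logits (2, 40, 90),
    # attn_scores (2, 40, 8, 32).
    print({key: value.shape for key, value in out.items()})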