Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import math | |
| import torch.nn as nn | |
| from mmcv.runner import ModuleList | |
| from mmocr.models.builder import ENCODERS | |
| from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding, | |
| SatrnEncoderLayer) | |
| from .base_encoder import BaseEncoder | |
@ENCODERS.register_module()
class SatrnEncoder(BaseEncoder):
    """Implement encoder for SATRN, see `SATRN.
    <https://arxiv.org/abs/1910.04396>`_.

    The input feature map is combined with an adaptive 2D positional
    encoding, flattened into a sequence of ``H * W`` tokens, passed through
    a stack of ``SatrnEncoderLayer`` modules and finally layer-normalized.

    Args:
        n_layers (int): Number of attention layers.
        n_head (int): Number of parallel attention heads.
        d_k (int): Dimension of the key vector.
        d_v (int): Dimension of the value vector.
        d_model (int): Dimension :math:`D_m` of the input from previous model.
        n_position (int): Capacity of the 2D positional encoding, used as
            both its height (``n_height``) and width (``n_width``). The
            feature map's ``H`` and ``W`` are presumably required not to
            exceed it (enforced, if at all, by
            ``Adaptive2DPositionalEncoding``).
        d_inner (int): Hidden dimension of feedforward layers.
        dropout (float): Dropout rate, shared by the positional encoding and
            every encoder layer.
        init_cfg (dict or list[dict], optional): Initialization configs.
        **kwargs: Accepted (e.g. for config compatibility) and ignored.
    """

    def __init__(self,
                 n_layers=12,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 n_position=100,
                 d_inner=256,
                 dropout=0.1,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.d_model = d_model
        self.position_enc = Adaptive2DPositionalEncoding(
            d_hid=d_model,
            n_height=n_position,
            n_width=n_position,
            dropout=dropout)
        self.layer_stack = ModuleList([
            SatrnEncoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, feat, img_metas=None):
        """Encode a feature map into a sequence of ``H * W`` features.

        Args:
            feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
                It is not modified.
            img_metas (list[dict], optional): One meta dict per image in the
                batch. Each may carry ``valid_ratio`` (default 1.0); columns
                at or beyond ``ceil(W * valid_ratio)`` are marked invalid in
                the attention mask passed to the encoder layers.

        Returns:
            Tensor: A tensor of shape :math:`(N, T, D_m)` where
            :math:`T = H \\times W`.

        Raises:
            ValueError: If ``img_metas`` does not contain exactly one entry
                per image in ``feat``.
        """
        valid_ratios = [1.0 for _ in range(feat.size(0))]
        if img_metas is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ]

        # Out-of-place add: ``feat += ...`` would overwrite the caller's
        # tensor (e.g. the backbone output) as a hidden side effect.
        feat = feat + self.position_enc(feat)
        n, c, h, w = feat.size()

        if len(valid_ratios) != n:
            # Fewer entries would leave whole rows of the mask at zero and
            # more would index past the batch; fail loudly instead.
            raise ValueError(
                f'Expected {n} img_metas (one per image), '
                f'got {len(valid_ratios)}.')

        # 1 marks columns inside each image's valid width, 0 marks the rest.
        # How the encoder layers consume this mask is defined in
        # SatrnEncoderLayer (presumably 0 positions are not attended to).
        mask = feat.new_zeros((n, h, w))
        for i, valid_ratio in enumerate(valid_ratios):
            valid_width = min(w, math.ceil(w * valid_ratio))
            mask[i, :, :valid_width] = 1
        mask = mask.view(n, h * w)

        # (N, C, H, W) -> (N, H*W, C): one token per spatial location.
        feat = feat.view(n, c, h * w)
        output = feat.permute(0, 2, 1).contiguous()
        for enc_layer in self.layer_stack:
            # h and w are forwarded so each layer can recover the 2D layout.
            output = enc_layer(output, h, w, mask)
        output = self.layer_norm(output)

        return output