# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmpose.models.utils.ops import resize
from mmpose.registry import MODELS


@MODELS.register_module()
class FeatureMapProcessor(nn.Module):
    """A PyTorch module for selecting, concatenating, and rescaling feature
    maps.

    Args:
        select_index (Optional[Union[int, Tuple[int]]], optional): Index or
            indices of feature maps to select. Defaults to None, which means
            all feature maps are used.
        concat (bool, optional): Whether to concatenate the selected feature
            maps. Defaults to False.
        scale_factor (float, optional): The scaling factor to apply to the
            feature maps. Defaults to 1.0.
        apply_relu (bool, optional): Whether to apply ReLU to the input
            feature maps. Defaults to False.
        align_corners (bool, optional): Whether to align corners when resizing
            the feature maps. Defaults to False.
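
    The example below is an illustrative sketch; the feature map shapes are
    arbitrary and only chosen to show how selection, concatenation and
    rescaling combine.

    Example::

        >>> import torch
        >>> # keep the last two feature maps, concatenate them at the
        >>> # resolution of the first selected map, then upsample by 2x
        >>> processor = FeatureMapProcessor(
        ...     select_index=(1, 2), concat=True, scale_factor=2.0)
        >>> feats = [torch.rand(1, 16, 64, 48),
        ...          torch.rand(1, 32, 32, 24),
        ...          torch.rand(1, 64, 16, 12)]
        >>> outs = processor(feats)
        >>> outs[0].shape
        torch.Size([1, 96, 64, 48])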
    """

    def __init__(
        self,
        select_index: Optional[Union[int, Tuple[int]]] = None,
        concat: bool = False,
        scale_factor: float = 1.0,
        apply_relu: bool = False,
        align_corners: bool = False,
    ):
        super().__init__()

        if isinstance(select_index, int):
            select_index = (select_index, )
        self.select_index = select_index
        self.concat = concat

        assert (
            scale_factor > 0
        ), f'the argument `scale_factor` must be positive, ' \
           f'but got {scale_factor}'
        self.scale_factor = scale_factor
        self.apply_relu = apply_relu
        self.align_corners = align_corners

    def forward(self, inputs: Union[Tensor, Sequence[Tensor]]
                ) -> Union[Tensor, List[Tensor]]:
        """Process the input feature map(s).

        Note that ``select_index`` and ``concat`` only take effect when
        ``inputs`` is a sequence of feature maps; a single-tensor input
        only goes through the optional ReLU and rescaling steps.
        """

        if not isinstance(inputs, (tuple, list)):
            sequential_input = False
            inputs = [inputs]
        else:
            sequential_input = True

            if self.select_index is not None:
                inputs = [inputs[i] for i in self.select_index]

            if self.concat:
                inputs = self._concat(inputs)

        if self.apply_relu:
            inputs = [F.relu(x) for x in inputs]

        if self.scale_factor != 1.0:
            inputs = self._rescale(inputs)

        if not sequential_input:
            inputs = inputs[0]

        return inputs

    def _concat(self, inputs: Sequence[Tensor]) -> List[Tensor]:
        """Resize all feature maps to the spatial size of the first one and
        concatenate them along the channel dimension."""
        size = inputs[0].shape[-2:]
        resized_inputs = [
            resize(
                x,
                size=size,
                mode='bilinear',
                align_corners=self.align_corners) for x in inputs
        ]
        return [torch.cat(resized_inputs, dim=1)]

    def _rescale(self, inputs: Sequence[Tensor]) -> List[Tensor]:
        """Rescale each feature map by ``self.scale_factor`` with bilinear
        interpolation."""
        rescaled_inputs = [
            resize(
                x,
                scale_factor=self.scale_factor,
                mode='bilinear',
                align_corners=self.align_corners,
            ) for x in inputs
        ]
        return rescaled_inputs
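
# The snippet below is an illustrative sketch, not part of the original
# module: because the class is decorated with ``@MODELS.register_module()``,
# it can also be built from a config dict via the registry. The field values
# here are arbitrary assumptions.
#
#     from mmpose.registry import MODELS
#
#     processor_cfg = dict(
#         type='FeatureMapProcessor',
#         select_index=(2, 3),
#         concat=True,
#         scale_factor=2.0)
#     processor = MODELS.build(processor_cfg)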