File size: 4,978 Bytes
e8f2571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Tuple

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType


@MODELS.register_module()
class FusedSemanticHead(BaseModule):
    r"""Multi-level fused semantic segmentation head.

    .. code-block:: none

        in_1 -> 1x1 conv ---
                            |
        in_2 -> 1x1 conv -- |
                           ||
        in_3 -> 1x1 conv - ||
                          |||                  /-> 1x1 conv (mask prediction)
        in_4 -> 1x1 conv -----> 3x3 convs (*4)
                            |                  \-> 1x1 conv (feature)
        in_5 -> 1x1 conv ---

    Args:
        num_ins (int): Number of input feature levels; one 1x1 lateral conv
            is created per level.
        fusion_level (int): Index of the input level whose spatial size all
            other levels are resized to before they are summed.
        seg_scale_factor (float): Scale factor applied to the ground truth
            labels in :meth:`loss` before the loss is computed.
            Defaults to 1/8.
        num_convs (int): Number of 3x3 convs applied to the fused feature.
            Defaults to 4.
        in_channels (int): Channels of every input feature map.
            Defaults to 256.
        conv_out_channels (int): Output channels of the 3x3 convs and of the
            embedding conv. Defaults to 256.
        num_classes (int): Number of output channels of the logits conv.
            Defaults to 183.
        conv_cfg (:obj:`ConfigDict` or dict, optional): ``conv_cfg`` passed
            to every ``ConvModule``. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): ``norm_cfg`` passed
            to every ``ConvModule``. Defaults to None.
        ignore_label (int, optional): Deprecated. When not None it overrides
            ``loss_seg['ignore_index']``. Defaults to None.
        loss_weight (float, optional): Deprecated. When not None it
            overrides ``loss_seg['loss_weight']``. Defaults to None.
        loss_seg (:obj:`ConfigDict` or dict): Config of the segmentation
            loss, built through ``MODELS.build``. It is never modified in
            place.
        init_cfg (:obj:`ConfigDict` or dict or list): Initialization config
            forwarded to ``BaseModule``.
    """  # noqa: W605

    def __init__(
        self,
        num_ins: int,
        fusion_level: int,
        seg_scale_factor: float = 1 / 8,
        num_convs: int = 4,
        in_channels: int = 256,
        conv_out_channels: int = 256,
        num_classes: int = 183,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        ignore_label: int = None,
        loss_weight: float = None,
        loss_seg: ConfigDict = dict(
            type='CrossEntropyLoss', ignore_index=255, loss_weight=0.2),
        init_cfg: MultiConfig = dict(
            type='Kaiming', override=dict(name='conv_logits'))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_ins = num_ins
        self.fusion_level = fusion_level
        self.seg_scale_factor = seg_scale_factor
        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.num_classes = num_classes
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False

        # One 1x1 lateral conv per input level.
        self.lateral_convs = nn.ModuleList()
        for _ in range(self.num_ins):
            self.lateral_convs.append(
                ConvModule(
                    self.in_channels,
                    self.in_channels,
                    1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=False))

        # Stack of 3x3 convs applied to the fused feature; only the first
        # one consumes ``in_channels``.
        self.convs = nn.ModuleList()
        for i in range(self.num_convs):
            in_channels = self.in_channels if i == 0 else conv_out_channels
            self.convs.append(
                ConvModule(
                    in_channels,
                    conv_out_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        self.conv_embedding = ConvModule(
            conv_out_channels,
            conv_out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)
        self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)

        # Compare against None rather than relying on truthiness so that
        # explicit values such as ``ignore_label=0`` are honoured.
        override_ignore = ignore_label is not None
        override_weight = loss_weight is not None
        if override_ignore or override_weight:
            warnings.warn('``ignore_label`` and ``loss_weight`` would be '
                          'deprecated soon. Please set ``ignore_index`` and '
                          '``loss_weight`` in ``loss_seg`` instead.')
            # Work on a copy: ``loss_seg`` may be the shared default
            # argument or a config object owned by the caller.
            loss_seg = loss_seg.copy()
            if override_ignore:
                loss_seg['ignore_index'] = ignore_label
            if override_weight:
                loss_seg['loss_weight'] = loss_weight
        self.criterion = MODELS.build(loss_seg)

    def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            feats (tuple[Tensor]): Multi scale feature maps.

        Returns:
            tuple[Tensor]:

                - mask_preds (Tensor): Predicted mask logits.
                - x (Tensor): Fused feature.
        """
        x = self.lateral_convs[self.fusion_level](feats[self.fusion_level])
        # Every other level is resized to the spatial size of the fusion
        # level before its lateral conv is applied.
        fused_size = tuple(x.shape[-2:])
        for i, feat in enumerate(feats):
            if i == self.fusion_level:
                continue
            feat = F.interpolate(
                feat, size=fused_size, mode='bilinear', align_corners=True)
            # fix runtime error of "+=" inplace operation in PyTorch 1.10
            x = x + self.lateral_convs[i](feat)

        for conv in self.convs:
            x = conv(x)

        # Both outputs branch from the same fused feature.
        mask_preds = self.conv_logits(x)
        x = self.conv_embedding(x)
        return mask_preds, x

    def loss(self, mask_preds: Tensor, labels: Tensor) -> Tensor:
        """Loss function.

        Args:
            mask_preds (Tensor): Predicted mask logits.
            labels (Tensor): Ground truth. Presumably of shape
                (N, 1, H, W), since dim 1 is squeezed after resizing
                -- TODO confirm against the caller.

        Returns:
            Tensor: Semantic segmentation loss.
        """
        # Nearest-neighbour resizing keeps the label values discrete.
        labels = F.interpolate(
            labels.float(), scale_factor=self.seg_scale_factor, mode='nearest')
        labels = labels.squeeze(1).long()
        loss_semantic_seg = self.criterion(mask_preds, labels)
        return loss_semantic_seg