File size: 4,978 Bytes
e8f2571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional, Tuple

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType
@MODELS.register_module()
class FusedSemanticHead(BaseModule):
    r"""Multi-level fused semantic segmentation head.

    .. code-block:: none

        in_1 -> 1x1 conv ---
                            |
        in_2 -> 1x1 conv -- |
                           ||
        in_3 -> 1x1 conv - ||
                          |||                  /-> 1x1 conv (mask prediction)
        in_4 -> 1x1 conv -----> 3x3 convs (*4)
                            |                  \-> 1x1 conv (feature)
        in_5 -> 1x1 conv ---

    Args:
        num_ins (int): Number of input feature levels. One 1x1 lateral conv
            is created per level.
        fusion_level (int): Index of the level whose spatial size every
            other level is resized to before fusion.
        seg_scale_factor (float): Scale factor applied to the ground-truth
            labels in :meth:`loss`. Defaults to 1 / 8.
        num_convs (int): Number of 3x3 convs applied to the fused feature.
            Defaults to 4.
        in_channels (int): Channels of every input level. Defaults to 256.
        conv_out_channels (int): Output channels of the 3x3 convs and of the
            embedding conv. Defaults to 256.
        num_classes (int): Number of output channels of the logits conv.
            Defaults to 183.
        conv_cfg (:obj:`ConfigDict` or dict, optional): ``conv_cfg``
            forwarded to every ``ConvModule``. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): ``norm_cfg``
            forwarded to every ``ConvModule``. Defaults to None.
        ignore_label (int, optional): Deprecated. When not None it overrides
            ``loss_seg['ignore_index']``. Defaults to None.
        loss_weight (float, optional): Deprecated. When not None it
            overrides ``loss_seg['loss_weight']``. Defaults to None.
        loss_seg (:obj:`ConfigDict` or dict): Config of the segmentation
            loss, built through ``MODELS.build``.
        init_cfg (:obj:`ConfigDict` or dict or list): Initialization config
            passed to ``BaseModule``.
    """  # noqa: W605

    def __init__(
        self,
        num_ins: int,
        fusion_level: int,
        seg_scale_factor: float = 1 / 8,
        num_convs: int = 4,
        in_channels: int = 256,
        conv_out_channels: int = 256,
        num_classes: int = 183,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        ignore_label: Optional[int] = None,
        loss_weight: Optional[float] = None,
        loss_seg: ConfigDict = dict(
            type='CrossEntropyLoss', ignore_index=255, loss_weight=0.2),
        init_cfg: MultiConfig = dict(
            type='Kaiming', override=dict(name='conv_logits'))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_ins = num_ins
        self.fusion_level = fusion_level
        self.seg_scale_factor = seg_scale_factor
        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.num_classes = num_classes
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False

        # One 1x1 lateral conv per input level (channel count unchanged).
        self.lateral_convs = nn.ModuleList()
        for _ in range(self.num_ins):
            self.lateral_convs.append(
                ConvModule(
                    self.in_channels,
                    self.in_channels,
                    1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=False))

        # Stack of 3x3 convs applied to the fused feature map; only the
        # first one sees ``in_channels`` inputs.
        self.convs = nn.ModuleList()
        for i in range(self.num_convs):
            in_channels = self.in_channels if i == 0 else conv_out_channels
            self.convs.append(
                ConvModule(
                    in_channels,
                    conv_out_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))

        # Two 1x1 output branches: feature embedding and class logits.
        self.conv_embedding = ConvModule(
            conv_out_channels,
            conv_out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)
        self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)

        # Legacy overrides. Compare against None so that an explicit 0
        # (a valid label index / loss weight) is honoured, and write into a
        # copy so neither the shared default ``loss_seg`` dict nor the
        # caller's config object is mutated.
        if ignore_label is not None or loss_weight is not None:
            loss_seg = loss_seg.copy()
            if ignore_label is not None:
                loss_seg['ignore_index'] = ignore_label
            if loss_weight is not None:
                loss_seg['loss_weight'] = loss_weight
            warnings.warn('``ignore_label`` and ``loss_weight`` would be '
                          'deprecated soon. Please set ``ignore_index`` and '
                          '``loss_weight`` in ``loss_seg`` instead.')

        self.criterion = MODELS.build(loss_seg)

    def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Level ``i`` of ``feats`` is processed by ``lateral_convs[i]``, so
        ``feats`` must not contain more than ``num_ins`` levels.

        Args:
            feats (tuple[Tensor]): Multi scale feature maps.

        Returns:
            tuple[Tensor]:

                - mask_preds (Tensor): Predicted mask logits.
                - x (Tensor): Fused feature.
        """
        # The fusion level defines the target spatial size and seeds the sum.
        x = self.lateral_convs[self.fusion_level](feats[self.fusion_level])
        fused_size = tuple(x.shape[-2:])
        for i, feat in enumerate(feats):
            if i != self.fusion_level:
                feat = F.interpolate(
                    feat, size=fused_size, mode='bilinear', align_corners=True)
                # fix runtime error of "+=" inplace operation in PyTorch 1.10
                x = x + self.lateral_convs[i](feat)

        for i in range(self.num_convs):
            x = self.convs[i](x)

        # Both outputs branch off the same post-conv feature map.
        mask_preds = self.conv_logits(x)
        x = self.conv_embedding(x)
        return mask_preds, x

    def loss(self, mask_preds: Tensor, labels: Tensor) -> Tensor:
        """Loss function.

        Args:
            mask_preds (Tensor): Predicted mask logits.
            labels (Tensor): Ground truth. Dimension 1 is squeezed after
                resizing, so it is presumably a singleton channel axis,
                e.g. (N, 1, H, W) -- confirm against the caller.

        Returns:
            Tensor: Semantic segmentation loss.
        """
        # Nearest-neighbour resizing keeps label values intact (no blending
        # of class ids); interpolation needs a floating-point input.
        labels = F.interpolate(
            labels.float(), scale_factor=self.seg_scale_factor, mode='nearest')
        labels = labels.squeeze(1).long()
        loss_semantic_seg = self.criterion(mask_preds, labels)
        return loss_semantic_seg
|