import torch
import torch.utils.checkpoint as cp
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from einops import rearrange, repeat

from .convnext import ConvNextBlock

@SEM_SEG_HEADS_REGISTRY.register()
class MASKAdapterHead(nn.Module):
    """Adapter head that fuses CLIP features with mask proposals to produce
    mask-conditioned semantic activation maps."""

    @configurable
    def __init__(
        self,
        clip_model_name,
        mask_in_chans: int,
        num_channels: int,
        use_checkpoint: bool,
        num_output_maps: int,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
        """
        super().__init__()
        self.use_checkpoint = use_checkpoint
        
        if '_base' in clip_model_name:
            clip_dim = 640
        elif '_large' in clip_model_name:
            clip_dim = 768
        else:
            raise ValueError(f"Cannot infer CLIP feature dim from '{clip_model_name}'")
        
        self.fuse = nn.Conv2d(clip_dim, num_channels, 1)

        # Three ConvNext blocks refine the fused CLIP + mask features.
        self.cnext1 = ConvNextBlock(num_channels)
        self.cnext2 = ConvNextBlock(num_channels)
        self.cnext3 = ConvNextBlock(num_channels)

        self.norm = nn.LayerNorm(num_channels)
        self.final = nn.Conv2d(num_channels, num_output_maps, 1)
        
        # Downscale masks 4x with two stride-2 convs to match the CLIP feature
        # resolution, then project to clip_dim for elementwise fusion.
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=3, stride=2, padding=1),
            LayerNorm2d(mask_in_chans // 4),
            nn.GELU(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=3, stride=2, padding=1),
            LayerNorm2d(mask_in_chans),
            nn.GELU(),
            nn.Conv2d(mask_in_chans, clip_dim, kernel_size=1),
        )
        

    @classmethod
    def from_config(cls, cfg):

        return {
            "clip_model_name": cfg.MODEL.FC_CLIP.CLIP_MODEL_NAME,
            "mask_in_chans": cfg.MODEL.MASK_ADAPTER.MASK_IN_CHANNELS,
            "num_channels": cfg.MODEL.MASK_ADAPTER.NUM_CHANNELS,
            "use_checkpoint": cfg.MODEL.MASK_ADAPTER.USE_CHECKPOINT,
            "num_output_maps": cfg.MODEL.MASK_ADAPTER.NUM_OUTPUT_MAPS,
        }
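
    # A hypothetical sketch of the config keys read above (key names match the
    # lookups in from_config; the values are illustrative assumptions, not
    # project defaults):
    #
    #   MODEL:
    #     FC_CLIP:
    #       CLIP_MODEL_NAME: "convnext_large_d_320"
    #     MASK_ADAPTER:
    #       MASK_IN_CHANNELS: 16
    #       NUM_CHANNELS: 768
    #       USE_CHECKPOINT: False
    #       NUM_OUTPUT_MAPS: 16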

    def forward(self, clip_feature, masks):
        """
        Args:
            clip_feature: (B, C, H, W) CLIP feature map.
            masks: (B, N, Hm, Wm) mask proposals; resized internally to (4H, 4W).
        Returns:
            (B, N * num_output_maps, H, W) mask-conditioned activation maps.
        """
        N = masks.size(1)
        masks = rearrange(masks, 'B N H W -> (B N) 1 H W')

        # Repeat the CLIP features once per mask so each mask is fused independently.
        clip_feature = repeat(clip_feature, "B C H W -> (B N) C H W", N=N)

        H, W = clip_feature.shape[-2:]
        masks = F.interpolate(masks, size=(H * 4, W * 4),
                              mode='bilinear', align_corners=False)
        masks = self.mask_downscaling(masks)

        outputs = clip_feature + masks
        
        def _inner_forward(outputs):
            outputs = self.fuse(outputs)
            outputs = self.cnext1(outputs)
            outputs = self.cnext2(outputs)
            outputs = self.cnext3(outputs)

            # nn.LayerNorm expects channels-last, so permute, normalize, permute back.
            outputs = outputs.permute(0, 2, 3, 1)
            outputs = self.norm(outputs.contiguous())
            outputs = outputs.permute(0, 3, 1, 2)

            outputs = self.final(outputs.contiguous())

            # Regroup the per-mask maps back into the batch dimension.
            outputs = rearrange(outputs, '(B N) C H W -> B (N C) H W', N=N)
            return outputs

        if self.use_checkpoint and self.training:
            # Gradient checkpointing trades recompute for activation memory.
            outputs = cp.checkpoint(_inner_forward, outputs, use_reentrant=False)
        else:
            outputs = _inner_forward(outputs)
        return outputs

def build_mask_adapter(cfg, name):
    return SEM_SEG_HEADS_REGISTRY.get(name)(cfg)
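
# Hypothetical usage of the builder (with register() taking no arguments, the
# registry key defaults to the class name):
#   mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead")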

# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the channel dim of an NCHW tensor, with per-channel affine.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
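

# A minimal, self-contained sanity check (an illustrative sketch, not part of
# the original module): LayerNorm2d over an NCHW tensor should match
# nn.LayerNorm(C) applied channels-last with the same eps. The MASKAdapterHead
# shapes noted below are assumptions based on forward() above.
if __name__ == "__main__":
    x = torch.randn(2, 8, 5, 5)
    ln2d = LayerNorm2d(8, eps=1e-6)
    ref = nn.LayerNorm(8, eps=1e-6)
    # Both modules start from weight=1, bias=0, so outputs should agree up to
    # floating-point tolerance.
    y = ln2d(x)
    y_ref = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(y, y_ref, atol=1e-6), "LayerNorm2d mismatch"
    print("LayerNorm2d matches nn.LayerNorm:", tuple(y.shape))

    # Expected MASKAdapterHead shapes (illustrative; instantiating the head
    # requires ConvNextBlock from the sibling convnext module):
    #   clip_feature: (B, clip_dim, H, W), e.g. (1, 768, 24, 24)
    #   masks:        (B, N, Hm, Wm),      e.g. (1, 100, 96, 96)
    #   output:       (B, N * num_output_maps, H, W)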