import torch
import torch.utils.checkpoint as cp
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from einops import rearrange, repeat

from .convnext import ConvNextBlock

@SEM_SEG_HEADS_REGISTRY.register()
class MASKAdapterHead(nn.Module):
@configurable
def __init__(
self,
clip_model_name,
mask_in_chans: int,
num_channels: int,
use_checkpoint: bool,
num_output_maps: int,
):
"""
NOTE: this interface is experimental.
Args:
input_shape: shapes (channels and stride) of the input features
num_classes: number of classes to predict
pixel_decoder: the pixel decoder module
loss_weight: loss weight
ignore_value: category id to be ignored during training.
transformer_predictor: the transformer decoder that makes prediction
transformer_in_feature: input feature name to the transformer_predictor
"""
super().__init__()
self.use_checkpoint = use_checkpoint
        if '_base' in clip_model_name:
            clip_dim = 640
        elif '_large' in clip_model_name:
            clip_dim = 768
        else:
            raise ValueError(
                f"Cannot infer CLIP feature dim from model name: {clip_model_name}")
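        # Fuse the (CLIP feature + downscaled mask) sum: a 1x1 projection to the
        # adapter width, followed by three ConvNeXt blocks and a final 1x1 conv
        # that emits num_output_maps activation maps per mask.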
self.fuse = nn.Conv2d(clip_dim, num_channels, 1)
self.cnext1 = ConvNextBlock(num_channels)
self.cnext2 = ConvNextBlock(num_channels)
self.cnext3 = ConvNextBlock(num_channels)
self.norm = nn.LayerNorm(num_channels)
self.final = nn.Conv2d(num_channels, num_output_maps, 1)
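        # Downscale a 1-channel mask by 4x (two stride-2 convs) and project it
        # to the CLIP feature dimension so it can be added to clip_feature.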
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=3, stride=2, padding=1),
LayerNorm2d(mask_in_chans // 4),
nn.GELU(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=3, stride=2, padding=1),
LayerNorm2d(mask_in_chans),
nn.GELU(),
nn.Conv2d(mask_in_chans, clip_dim, kernel_size=1),
)
@classmethod
def from_config(cls, cfg):
return {
"clip_model_name": cfg.MODEL.FC_CLIP.CLIP_MODEL_NAME,
"mask_in_chans": cfg.MODEL.MASK_ADAPTER.MASK_IN_CHANNELS,
"num_channels": cfg.MODEL.MASK_ADAPTER.NUM_CHANNELS,
"use_checkpoint": cfg.MODEL.MASK_ADAPTER.USE_CHECKPOINT,
"num_output_maps": cfg.MODEL.MASK_ADAPTER.NUM_OUTPUT_MAPS,
}
    def forward(self, clip_feature, masks):
        # clip_feature: (B, C, H, W) CLIP features; masks: (B, N, Hm, Wm) per-image masks.
        N = masks.size(1)

        # Fold the N masks into the batch dimension and repeat the CLIP features
        # so every mask is paired with its image's features.
        masks = rearrange(masks, 'B N H W -> (B N) H W').unsqueeze(dim=1)
        clip_feature = repeat(clip_feature, "B C H W -> (B N) C H W", N=N)

        H, W = clip_feature.shape[-2:]
        # Resize masks to 4x the feature resolution; mask_downscaling (two
        # stride-2 convs) brings them back to (H, W) at the CLIP dimension.
        masks = F.interpolate(masks, size=(H * 4, W * 4),
                              mode='bilinear', align_corners=False)

        masks = self.mask_downscaling(masks)
        outputs = clip_feature + masks
        def _inner_forward(outputs):
            outputs = self.fuse(outputs)
            outputs = self.cnext1(outputs)
            outputs = self.cnext2(outputs)
            outputs = self.cnext3(outputs)
            # LayerNorm expects channels-last, so permute, normalize, permute back.
            outputs = outputs.permute(0, 2, 3, 1)
            outputs = self.norm(outputs.contiguous())
            outputs = outputs.permute(0, 3, 1, 2)
            outputs = self.final(outputs.contiguous())
            # Unfold the masks from the batch: each mask contributes C output maps.
            outputs = rearrange(outputs, '(B N) C H W -> B (N C) H W', N=N)
            return outputs
        # Optionally trade compute for memory with gradient checkpointing.
        if self.use_checkpoint and self.training:
            outputs = cp.checkpoint(_inner_forward, outputs, use_reentrant=False)
        else:
            outputs = _inner_forward(outputs)
        return outputs
def build_mask_adapter(cfg, name):
    return SEM_SEG_HEADS_REGISTRY.get(name)(cfg)
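
# A sketch of the config keys consumed by from_config above. The key paths come
# from the code itself; the values are illustrative assumptions, not released
# defaults:
#
#   MODEL:
#     FC_CLIP:
#       CLIP_MODEL_NAME: "convnext_large_d_320"
#     MASK_ADAPTER:
#       MASK_IN_CHANNELS: 16
#       NUM_CHANNELS: 768
#       USE_CHECKPOINT: True
#       NUM_OUTPUT_MAPS: 16
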
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
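

if __name__ == "__main__":
    # Minimal shape-check sketch (not part of the original file; all sizes
    # below are illustrative assumptions). @configurable lets the head be built
    # from explicit keyword arguments instead of a cfg. Note the relative
    # import of ConvNextBlock means this file must be run as part of its
    # package, and the assert assumes ConvNextBlock preserves tensor shape.
    head = MASKAdapterHead(
        clip_model_name="convnext_large_d_320",  # '_large' -> clip_dim = 768
        mask_in_chans=16,
        num_channels=768,
        use_checkpoint=False,
        num_output_maps=16,
    )
    clip_feature = torch.randn(2, 768, 24, 24)          # (B, C, H, W)
    masks = (torch.randn(2, 5, 384, 384) > 0).float()   # (B, N, Hm, Wm)
    out = head(clip_feature, masks)
    # One group of num_output_maps activation maps per mask, at feature resolution.
    assert out.shape == (2, 5 * 16, 24, 24)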