# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from .position_encoding import PositionEmbeddingSine
from .transformer import Transformer


class StandardTransformerDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        num_classes,
        mask_classification=True,
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dropout=0.0,
        dim_feedforward=2048,
        enc_layers=0,
        dec_layers=10,
        pre_norm=False,
        deep_supervision=True,
        mask_dim=256,
        enforce_input_project=False,
    ):
        super().__init__()

        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            normalize_before=pre_norm,
            return_intermediate_dec=deep_supervision,
        )

        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model

        # learnable query embeddings, one per object query
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # project input features to the transformer dimension if they differ;
        # features are 2D feature maps [B, C, H, W], so a 1x1 Conv2d is used
        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
            weight_init.c2_xavier_fill(self.input_proj)
        else:
            self.input_proj = nn.Sequential()
        self.aux_loss = deep_supervision

        # output FFNs
        if self.mask_classification:
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, x, mask_features, mask=None):
        if mask is not None:
            mask = F.interpolate(mask[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
        pos = self.pe_layer(x, mask)

        src = x
        hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)

        if self.mask_classification:
            outputs_class = self.class_embed(hs)
            out = {"pred_logits": outputs_class[-1]}
        else:
            out = {}

        if self.aux_loss:
            # [l, bs, queries, embed]
            mask_embed = self.mask_embed(hs)
            outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks[-1]
            out["aux_outputs"] = self._set_aux_loss(
                outputs_class if self.mask_classification else None, outputs_seg_masks
            )
        else:
            # FIXME h_boxes takes the last one computed, keep this in mind
            # [bs, queries, embed]
            mask_embed = self.mask_embed(hs[-1])
            outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        if self.mask_classification:
            return [
                {"pred_logits": a, "pred_masks": b}
                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
            ]
        else:
            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
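

# ---------------------------------------------------------------------------
# Minimal shape-check sketch (not part of the original file). It assumes the
# relative imports above resolve, i.e. this module is run inside its package
# (e.g. `python -m <package>.<this_module>`; the package name is unknown here)
# and that `Transformer` behaves like DETR's, returning intermediate decoder
# states of shape [l, bs, queries, embed] when return_intermediate_dec=True.
# The tensor shapes follow the einsum patterns above ("bchw" spatial features).
# All sizes below are illustrative, not values from the original code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = StandardTransformerDecoder(
        in_channels=256,
        num_classes=80,
        hidden_dim=256,
        num_queries=100,
        mask_dim=256,
    )
    x = torch.randn(2, 256, 32, 32)  # transformer input features [B, C, H, W]
    mask_features = torch.randn(2, 256, 128, 128)  # per-pixel embeddings [B, mask_dim, H', W']
    out = decoder(x, mask_features)
    # expected: pred_logits [2, 100, 81] (num_classes + 1 for "no object"),
    # pred_masks [2, 100, 128, 128] at the mask_features resolution
    print(out["pred_logits"].shape)
    print(out["pred_masks"].shape)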