from typing import Optional, List, Union

import torch
from torch import nn, Tensor
from torch.cuda.amp import autocast

from ...utils.utils import MLP, _get_clones, _get_activation_fn, gen_sineembed_for_position, inverse_sigmoid
from ..pixel_decoder.ops.modules import MSDeformAttn
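
# Deformable transformer decoder (DINO-style) with iterative box refinement.
# TransformerDecoder stacks DeformableTransformerDecoderLayer modules; prompt
# tokens (grounding / visual prompts) can be injected into self-attention, and
# an optional cross-track attention layer refines track queries after decoding.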


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None,
                 return_intermediate=False,
                 d_model=256, query_dim=4,
                 modulate_hw_attn=True,
                 num_feature_levels=1,
                 deformable_decoder=True,
                 decoder_query_perturber=None,
                 dec_layer_number=None,
                 rm_dec_query_scale=True,
                 dec_layer_share=False,
                 dec_layer_dropout_prob=None,
                 cross_track_layer=False,
                 n_levels=None,
                 n_heads=None,
                 n_points=None,
                 ):
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share)
        else:
            self.layers = []
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        assert return_intermediate, "support return_intermediate only"
        self.query_dim = query_dim
        assert query_dim in [2, 4], "query_dim should be 2/4 but got {}".format(query_dim)
        self.num_feature_levels = num_feature_levels

        # Each of the query_dim box coordinates is encoded with a d_model/2-dim
        # sine embedding, so this head maps (query_dim // 2 * d_model) -> d_model.
        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None

        if rm_dec_query_scale:
            self.query_scale = None
        else:
            # A learned per-query scale is not supported in this configuration.
            raise NotImplementedError

        # bbox_embed / class_embed are attached by the parent model after construction.
        self.bbox_embed = None
        self.class_embed = None

        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder

        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None

        self.decoder_query_perturber = decoder_query_perturber
        self.box_pred_damping = None

        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers

        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            assert isinstance(dec_layer_dropout_prob, list)
            assert len(dec_layer_dropout_prob) == num_layers
            for prob in dec_layer_dropout_prob:
                assert 0.0 <= prob <= 1.0

        if cross_track_layer:
            # Extra deformable-attention layer that lets the final queries
            # attend to the encoder memory once more (for track queries).
            self.cross_track_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
            self.cross_track = True
        else:
            self.cross_track = False

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,
                # for memory
                level_start_index: Optional[Tensor] = None,
                spatial_shapes: Optional[Tensor] = None,
                valid_ratios: Optional[Tensor] = None,
                task=None,
                extra=None,
                ):
        """
        Input:
            - tgt: nq, bs, d_model
            - memory: hw, bs, d_model
            - pos: hw, bs, d_model
            - refpoints_unsigmoid: nq, bs, 2/4
            - valid_ratios: bs, nlevel, 2
            - spatial_shapes: nlevel, 2
        """
        output = tgt
        device = tgt.device

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid().to(device)
        ref_points = [reference_points]

        for layer_id, layer in enumerate(self.layers):

            # Optionally perturb the reference boxes during training
            # (denoising-style query perturbation); never on the first layer.
            if self.training and self.decoder_query_perturber is not None and layer_id != 0:
                reference_points = self.decoder_query_perturber(reference_points)

            # Scale the normalized boxes by the per-level valid ratios so the
            # sampling locations ignore padded regions: (nq, bs, nlevel, 4).
            reference_points_input = reference_points[:, :, None] \
                * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
            query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :])

            # Positional part of the query, derived from the box sine embedding.
            raw_query_pos = self.ref_point_head(query_sine_embed)
            pos_scale = self.query_scale(output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos

            output = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,

                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,

                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
                task=task,
                extra=extra,
                layer_id=layer_id,
            )

            # Iterative box refinement: predict a box delta in logit space and
            # add it to the inverse-sigmoid of the current reference points.
            if self.bbox_embed is not None:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](output).to(device)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                # Stop gradients between layers ("look forward once").
                reference_points = new_reference_points.detach()
                ref_points.append(new_reference_points)

            intermediate.append(self.norm(output))

        # Optional cross-track attention: one extra deformable-attention pass
        # of the final queries over the encoder memory, used for track queries.
        if self.cross_track:
            tgt_track = self.cross_track_attn(self.with_pos_embed(output, query_pos).transpose(0, 1),
                                              reference_points_input.transpose(0, 1).contiguous(),
                                              memory.transpose(0, 1), spatial_shapes, level_start_index,
                                              memory_key_padding_mask).transpose(0, 1)
            tgt_track = tgt_track + output
            tgt_track = tgt_track.transpose(0, 1)
        else:
            tgt_track = None

        # Return batch-first tensors: (bs, nq, d_model) and (bs, nq, 2/4).
        return [
            [itm_out.transpose(0, 1) for itm_out in intermediate],
            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
            tgt_track,
        ]


class DeformableTransformerDecoderLayer(nn.Module):

    def __init__(self, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4,
                 use_deformable_box_attn=False,
                 key_aware_type=None,
                 ):
        super().__init__()
        self.n_heads = n_heads

        # cross attention (deformable)
        if use_deformable_box_attn:
            raise NotImplementedError
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # feed-forward network
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

        self.key_aware_type = key_aware_type
        self.key_aware_proj = None

    def rm_self_attn_modules(self):
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt
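
    # forward() applies, in order (post-norm residual blocks):
    #   1. self-attention over the queries, optionally with prompt tokens appended
    #   2. an optional key-aware step that adds a memory summary to the queries
    #   3. deformable cross-attention into the multi-scale encoder memory
    #   4. the feed-forward network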
    @autocast(enabled=False)
    def forward(self,
                # for tgt
                tgt: Optional[Tensor],
                tgt_query_pos: Optional[Tensor] = None,
                tgt_query_sine_embed: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                tgt_reference_points: Optional[Tensor] = None,

                # for memory
                memory: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                memory_level_start_index: Optional[Tensor] = None,
                memory_spatial_shapes: Optional[Tensor] = None,
                memory_pos: Optional[Tensor] = None,

                # attention masks and extras
                self_attn_mask: Optional[Tensor] = None,
                cross_attn_mask: Optional[Tensor] = None,  # accepted for interface parity; unused here
                task=None,
                extra=None,
                layer_id=None,
                ):
        """
        Input:
            - tgt/tgt_query_pos: nq, bs, d_model
            - memory: hw, bs, d_model
        """

        # Prompt-conditioned self-attention: for grounding / referring tasks
        # (and visual prompts), prompt tokens are appended to the object
        # queries before self-attention and stripped off again afterwards.
        if task in ['grounding', 'rvos'] or (extra is not None and 'visual_prompt_tokens' in extra):
            if 'visual_prompt_tokens' in extra:
                # One prompt-token set per feature level, cycled by layer id.
                level_index = layer_id % 3
                prompt_tokens = extra['visual_prompt_tokens'][level_index]
                prompt_mask = extra['visual_prompt_nonzero_mask'][level_index]
            else:
                prompt_tokens = extra['grounding_tokens']
                prompt_mask = extra['grounding_nonzero_mask']
            # The prompt tokens double as their own positional embedding.
            prompt_pos = prompt_tokens.detach().clone()

            # Extend the self-attention mask to cover queries + prompt tokens
            # (True = attention blocked).
            ori_size = tgt.shape[0]
            new_mask_size = ori_size + prompt_tokens.shape[0]
            new_self_attn_mask = torch.zeros((tgt.shape[1], new_mask_size, new_mask_size),
                                             dtype=torch.bool, device=tgt.device)
            if self_attn_mask is not None:
                # Keep the original query-to-query mask in the top-left block.
                new_self_attn_mask[:, :ori_size, :ori_size] = \
                    self_attn_mask.unsqueeze(0).repeat(tgt.shape[1], 1, 1)
            # Block query<->prompt attention at the positions selected by prompt_mask.
            new_self_attn_mask[:, :ori_size, ori_size:].transpose(1, 2)[prompt_mask] = True
            new_self_attn_mask[:, ori_size:, :ori_size][prompt_mask] = True
            # nn.MultiheadAttention expects a (bs * n_heads, L, L) attention mask.
            new_self_attn_mask = new_self_attn_mask.repeat_interleave(self.n_heads, dim=0)
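            # Resulting block structure of the (nq + np) x (nq + np) mask
            # (True = attention blocked):
            #
            #              queries                   prompts
            #   queries  [ self_attn_mask (or 0)   blocked at prompt_mask ]
            #   prompts  [ blocked at prompt_mask  0                      ]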

            if self.self_attn is not None:
                # Append prompt tokens, run self-attention over the combined
                # sequence, then keep only the original object queries.
                tgt = torch.cat([tgt, prompt_tokens], dim=0)
                tgt_query_pos = torch.cat([tgt_query_pos, prompt_pos], dim=0)
                q = k = self.with_pos_embed(tgt, tgt_query_pos)
                tgt2 = self.self_attn(q, k, tgt, attn_mask=new_self_attn_mask)[0]
                tgt = tgt + self.dropout2(tgt2)
                tgt = self.norm2(tgt)
                tgt = tgt[:ori_size]
                tgt_query_pos = tgt_query_pos[:ori_size]
        else:
            if self.self_attn is not None:
                q = k = self.with_pos_embed(tgt, tgt_query_pos)
                tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
                tgt = tgt + self.dropout2(tgt2)
                tgt = self.norm2(tgt)

        # Key-aware adjustment: add a summary of the memory to the queries.
        if self.key_aware_type is not None:
            if self.key_aware_type == 'mean':
                tgt = tgt + memory.mean(0, keepdim=True)
            elif self.key_aware_type == 'proj_mean':
                tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)
            else:
                raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type))

        # Deformable cross-attention over the multi-scale encoder memory.
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
                               tgt_reference_points.transpose(0, 1).contiguous(),
                               memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index,
                               memory_key_padding_mask).transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        tgt = self.forward_ffn(tgt)

        return tgt
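
# ---------------------------------------------------------------------------
# Shape sketch (illustrative only; the variable names below are assumptions,
# not part of this module): with nq queries, batch size bs, d_model = 256 and
# a flattened multi-scale memory of hw tokens,
#
#   outs, refs, tgt_track = decoder(
#       tgt,                                  # (nq, bs, 256)
#       memory,                               # (hw, bs, 256)
#       refpoints_unsigmoid=refpoints,        # (nq, bs, 4)
#       level_start_index=level_start_index,  # (nlevel,)
#       spatial_shapes=spatial_shapes,        # (nlevel, 2)
#       valid_ratios=valid_ratios,            # (bs, nlevel, 2)
#       task=None, extra={},
#   )
#
# outs holds one (bs, nq, 256) tensor per layer, refs the corresponding
# (bs, nq, 4) reference boxes (including the initial ones); tgt_track is None
# unless cross_track_layer was enabled.
# ---------------------------------------------------------------------------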