import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers import DropPath


class VLFuse(torch.nn.Module):
    """
    Early Fusion Module
    """

    def __init__(self):
        super(VLFuse, self).__init__()
        self.init_configs()

        # Bi-directional vision-language attention block used for early fusion.
        self.b_attn = BiAttentionBlockForCheckpoint(v_dim=self.img_dim,
                                                    l_dim=self.lang_dim,
                                                    embed_dim=self.embed_dim,
                                                    num_heads=self.n_head,
                                                    dropout=0.1,
                                                    drop_path=.0,
                                                    init_values=1.0 / 6,
                                                    )

    def init_configs(self):
        # Fusion hyper-parameters.
        self.img_dim = 256
        self.max_query_len = 256
        self.n_layers = 1
        self.n_head = 8
        self.embed_dim = 2048
        self.lang_dim = 256

    def forward(self, x, task=None):
        visual_features = x["visual"]
        language_dict_features = x["lang"]

        fused_visual_features, language_features = self.b_attn(
            visual_features, language_dict_features['hidden'], language_dict_features['masks'], task)

        language_dict_features['hidden'] = language_features
        fused_language_dict_features = language_dict_features

        features_dict = {"visual": fused_visual_features,
                         "lang": fused_language_dict_features}

        return features_dict
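
# Expected input format for VLFuse.forward (shapes are illustrative assumptions,
# not values prescribed by the original code):
#   x = {"visual": Tensor[B, N_v, img_dim],
#        "lang": {"hidden": Tensor[B, N_l, lang_dim],
#                 "masks": Tensor[B, N_l]}}      # 1 = valid language token, 0 = padding
# A runnable smoke test is sketched at the bottom of this file.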


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.

    Returns a [N, 4] tensor, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    # 'ij' indexing keeps the original row-major behaviour and avoids the
    # deprecation warning raised by newer PyTorch versions.
    y, x = torch.meshgrid(y, x, indexing="ij")

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
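
# Quick sanity-check example (illustrative only, values assumed):
#
#   m = torch.zeros(1, 16, 16)
#   m[0, 4:8, 2:10] = 1                      # a single rectangular mask
#   masks_to_boxes(m)                        # -> tensor([[2., 4., 9., 7.]]) in xyxy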


class FeatureFuser(nn.Module):
    """
    Feature Fuser for SOT (inspired by CondInst)
    """

    def __init__(self, in_channels, channels=256):
        super().__init__()

        # One 3x3 refinement conv per input feature level.
        self.refine = nn.ModuleList()
        for in_channel in in_channels:
            self.refine.append(nn.Conv2d(in_channel, channels, 3, padding=1))

    def forward(self, features):
        # Fuse the last three feature levels: refine each one, upsample it to the
        # resolution of the first selected level (features[-3]), and sum.
        for i, f in enumerate([-3, -2, -1]):
            if i == 0:
                x = self.refine[i](features[f])
            else:
                x_p = self.refine[i](features[f])
                target_h, target_w = x.size()[2:]
                h, w = x_p.size()[2:]
                assert target_h % h == 0
                assert target_w % w == 0
                factor_h, factor_w = target_h // h, target_w // w
                assert factor_h == factor_w
                x_p = aligned_bilinear(x_p, factor_h)
                x = x + x_p
        return x
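
# Illustrative sketch (channel counts and spatial sizes below are assumptions):
#
#   fuser = FeatureFuser(in_channels=[512, 1024, 2048], channels=256)
#   feats = [torch.randn(1, 512, 64, 64),
#            torch.randn(1, 1024, 32, 32),
#            torch.randn(1, 2048, 16, 16)]
#   fused = fuser(feats)                     # -> (1, 256, 64, 64)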


def aligned_bilinear(tensor, factor):
    """Upsample an NCHW tensor by an integer factor with pixel-aligned bilinear
    interpolation; the output has spatial size (factor * h, factor * w)."""
    assert tensor.dim() == 4
    assert factor >= 1
    assert int(factor) == factor

    if factor == 1:
        return tensor

    h, w = tensor.size()[2:]
    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
    oh = factor * h + 1
    ow = factor * w + 1
    tensor = F.interpolate(
        tensor, size=(oh, ow),
        mode='bilinear',
        align_corners=True
    )
    tensor = F.pad(
        tensor, pad=(factor // 2, 0, factor // 2, 0),
        mode="replicate"
    )

    return tensor[:, :, :oh - 1, :ow - 1]
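
# Shape sanity check (illustrative, assumed sizes):
#
#   t = torch.randn(1, 256, 16, 16)
#   aligned_bilinear(t, 4).shape             # -> torch.Size([1, 256, 64, 64])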


class BiMultiHeadAttention(nn.Module):
    def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1):
        super(BiMultiHeadAttention, self).__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.v_dim = v_dim
        self.l_dim = l_dim

        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
        self.scale = self.head_dim ** (-0.5)
        self.dropout = dropout

        self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
        self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)

        self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
        self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)

        self.stable_softmax_2d = False
        self.clamp_min_for_underflow = True
        self.clamp_max_for_overflow = True

        self._reset_parameters()

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.l_proj.weight)
        self.l_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.values_v_proj.weight)
        self.values_v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.values_l_proj.weight)
        self.values_l_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_v_proj.weight)
        self.out_v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_l_proj.weight)
        self.out_l_proj.bias.data.fill_(0)

    def forward(self, v, l, attention_mask_l=None):
        bsz, tgt_len, embed_dim = v.size()

        query_states = self.v_proj(v) * self.scale
        key_states = self._shape(self.l_proj(l), -1, bsz)
        value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
        value_l_states = self._shape(self.values_l_proj(l), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_v_states = value_v_states.view(*proj_shape)
        value_l_states = value_l_states.view(*proj_shape)

        src_len = key_states.size(1)
        # Shared similarity logits: one matrix of visual-to-language scores.
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )

        if self.stable_softmax_2d:
            attn_weights = attn_weights - attn_weights.max()

        if self.clamp_min_for_underflow:
            attn_weights = torch.clamp(attn_weights, min=-50000)  # avoid fp16 underflow
        if self.clamp_max_for_overflow:
            attn_weights = torch.clamp(attn_weights, max=50000)  # avoid fp16 overflow

        # Language-to-visual attention: softmax over the visual axis of the
        # transposed logits.
        attn_weights_T = attn_weights.transpose(1, 2)
        attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0])
        if self.clamp_min_for_underflow:
            attn_weights_l = torch.clamp(attn_weights_l, min=-50000)
        if self.clamp_max_for_overflow:
            attn_weights_l = torch.clamp(attn_weights_l, max=50000)

        attn_weights_l = attn_weights_l.softmax(dim=-1)

        # Visual-to-language attention: mask out padded language tokens before
        # the softmax over the language axis.
        if attention_mask_l is not None:
            assert (attention_mask_l.dim() == 2)
            attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1)
            attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
            attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15)

            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)

        attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
        attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)

        attn_output_v = torch.bmm(attn_probs_v, value_l_states)
        attn_output_l = torch.bmm(attn_probs_l, value_v_states)

        if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output_v` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
            )

        if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
            raise ValueError(
                f"`attn_output_l` should be of size {(bsz * self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
            )

        attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output_v = attn_output_v.transpose(1, 2)
        attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)

        attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
        attn_output_l = attn_output_l.transpose(1, 2)
        attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)

        attn_output_v = self.out_v_proj(attn_output_v)
        attn_output_l = self.out_l_proj(attn_output_l)

        return attn_output_v, attn_output_l
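
# Minimal shape sketch (dimensions below are assumptions for illustration):
#
#   attn = BiMultiHeadAttention(v_dim=256, l_dim=256, embed_dim=2048, num_heads=8)
#   v = torch.randn(2, 100, 256)                 # visual tokens
#   l = torch.randn(2, 20, 256)                  # language tokens
#   mask = torch.ones(2, 20)                     # 1 = valid, 0 = padded
#   dv, dl = attn(v, l, attention_mask_l=mask)   # dv: (2, 100, 256), dl: (2, 20, 256)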


class BiAttentionBlockForCheckpoint(nn.Module):
    def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1,
                 drop_path=.0, init_values=1e-4, ):
        """
        Inputs:
            v_dim - Dimensionality of the visual features
            l_dim - Dimensionality of the language features
            embed_dim - Dimensionality of input and attention feature vectors
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super(BiAttentionBlockForCheckpoint, self).__init__()

        # Pre-norm layers for each modality.
        self.layer_norm_v = nn.LayerNorm(v_dim)
        self.layer_norm_l = nn.LayerNorm(l_dim)
        self.attn = BiMultiHeadAttention(v_dim=v_dim,
                                         l_dim=l_dim,
                                         embed_dim=embed_dim,
                                         num_heads=num_heads,
                                         dropout=dropout,
                                         )

        # Per-channel residual scaling (layer scale) with optional stochastic depth.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
        self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)

    def forward(self, v, l, attention_mask_l=None, task=None):
        v = self.layer_norm_v(v)
        l = self.layer_norm_l(l)
        delta_v, delta_l = self.attn(v, l, attention_mask_l=attention_mask_l)

        v = v + self.drop_path(self.gamma_v * delta_v)
        l = l + self.drop_path(self.gamma_l * delta_l)
        return v, l
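

# Hedged, self-contained smoke test; the tensor sizes below are assumptions
# chosen only to exercise the fusion path, not values used by the original
# training setup.
if __name__ == "__main__":
    fuse = VLFuse()
    inputs = {
        "visual": torch.randn(2, 100, 256),              # (batch, visual tokens, img_dim)
        "lang": {
            "hidden": torch.randn(2, 256, 256),          # (batch, max_query_len, lang_dim)
            "masks": torch.ones(2, 256),                 # attention mask, 1 = valid token
        },
    }
    out = fuse(inputs)
    print(out["visual"].shape, out["lang"]["hidden"].shape)  # expected: (2, 100, 256) (2, 256, 256)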