import copy
import math
import os

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        # ReLU between hidden layers; no activation after the final layer.
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


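# Illustrative usage of MLP (a sketch; the dimensions below are assumptions, not
# fixed anywhere in this file). A 3-layer head mapping 256-d features to 4 box
# coordinates applies each nn.Linear over the last dimension:
#
#   mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#   out = mlp(torch.rand(2, 100, 256))  # -> shape (2, 100, 4)

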
def inverse_sigmoid(x, eps=1e-5):
    """Inverse of torch.sigmoid (the logit), clamped for numerical stability near 0 and 1."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


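# Quick check (illustrative): inverse_sigmoid undoes torch.sigmoid up to the eps clamping,
# e.g. torch.sigmoid(inverse_sigmoid(torch.tensor([0.25]))) recovers tensor([0.2500]).

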
def gen_encoder_output_proposals(memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor):
    r"""
    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw}
        - spatial_shapes: nlevel, 2
    Output:
        - output_memory: bs, \sum{hw}, d_model
        - output_proposals: bs, \sum{hw}, 4
    """
    N_, S_, C_ = memory.shape
    base_scale = 4.0
    proposals = []
    _cur = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
        # Number of valid (non-padded) rows/columns per image at this level.
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                        torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

        # Cell centers normalized by the valid extent, paired with a level-dependent box size.
        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += (H_ * W_)
    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
    # Proposals are returned in inverse-sigmoid (logit) space; padded or out-of-range
    # locations are filled with inf so they can be masked out downstream.
    output_proposals = torch.log(output_proposals / (1 - output_proposals))
    output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

    output_memory = memory
    output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
    return output_memory, output_proposals


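# Illustrative usage (a sketch; batch size, channels and level sizes are assumptions):
# two feature levels of 8x8 and 4x4 cells, flattened and concatenated along the token axis.
#
#   spatial_shapes = torch.tensor([[8, 8], [4, 4]])
#   memory = torch.rand(2, 8 * 8 + 4 * 4, 256)
#   padding_mask = torch.zeros(2, 8 * 8 + 4 * 4, dtype=torch.bool)  # no padding
#   out_mem, props = gen_encoder_output_proposals(memory, padding_mask, spatial_shapes)
#   # out_mem: (2, 80, 256); props: (2, 80, 4), boxes in inverse-sigmoid space

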
def gen_sineembed_for_position(pos_tensor):
    # pos_tensor: (n_query, bs, 2 or 4) with normalized (x, y[, w, h]) in [0, 1];
    # each coordinate is expanded into 128 sine/cosine features.
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1): {}".format(pos_tensor.size(-1)))
    return pos


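# Illustrative usage (a sketch; the leading dimensions are assumptions): 4-d boxes
# (cx, cy, w, h) in [0, 1] with shape (num_queries, batch, 4) give a 512-d embedding
# (128 dims per coordinate); 2-d inputs give 256-d.
#
#   boxes = torch.rand(300, 2, 4)
#   embed = gen_sineembed_for_position(boxes)  # -> shape (300, 2, 512)

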
def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    if activation == "prelu":
        # Unlike the others, PReLU is an nn.Module with a learnable parameter.
        return nn.PReLU()
    if activation == "selu":
        return F.selu
    raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")


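# Example (illustrative): the returned callable is applied directly to tensors.
#
#   act = _get_activation_fn("gelu")
#   y = act(torch.randn(4, 8))
#
# Note that "prelu" returns an nn.PReLU() instance, so it should be assigned to the
# owning module for its parameter to be registered and trained.

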
def _get_clones(module, N, layer_share=False):
    # With layer_share=True the same module instance (and its weights) is reused N times;
    # otherwise each entry is an independent deep copy.
    if layer_share:
        return nn.ModuleList([module for i in range(N)])
    else:
        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


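# Illustrative usage (layer count and sizes are assumptions):
#
#   stacked = _get_clones(nn.Linear(256, 256), 6)                    # 6 independent copies
#   shared = _get_clones(nn.Linear(256, 256), 6, layer_share=True)   # one instance reused 6 times

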
def _get_clones_advanced(module, N, N_valid):
    # The first N_valid entries are deep copies of `module`; the remaining slots are
    # nn.Identity placeholders that pass their input through unchanged.
    assert N_valid <= N
    layers = []
    for i in range(N):
        if i < N_valid:
            layers.append(copy.deepcopy(module))
        else:
            layers.append(nn.Identity())
    return nn.ModuleList(layers)


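# Illustrative usage (counts are assumptions): a 6-slot stack where only the first
# 4 slots carry real layers and the last 2 are pass-through.
#
#   layers = _get_clones_advanced(nn.Linear(256, 256), N=6, N_valid=4)
#   # layers[5](x) is x unchanged, since slots 4 and 5 are nn.Identity()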