import torch
import torch.nn as nn
import torch.nn.functional as F

from cotracker.models.core.model_utils import sample_features4d, sample_features5d
from cotracker.models.core.embeddings import (
    get_2d_embedding,
    get_1d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed,
)

from cotracker.models.core.cotracker.blocks import (
    Mlp,
    BasicEncoder,
    AttnBlock,
    CorrBlock,
    Attention,
)

torch.manual_seed(0)


class CoTracker2(nn.Module):
    def __init__(
        self,
        window_len=8,
        stride=4,
        add_space_attn=True,
        num_virtual_tracks=64,
        model_resolution=(384, 512),
    ):
        super(CoTracker2, self).__init__()
        self.window_len = window_len
        self.stride = stride
        self.hidden_dim = 256
        self.latent_dim = 128
        self.add_space_attn = add_space_attn
        self.fnet = BasicEncoder(output_dim=self.latent_dim)
        self.num_virtual_tracks = num_virtual_tracks
        self.model_resolution = model_resolution
        self.input_dim = 456
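        # Token size fed to the update transformer. Assuming get_2d_embedding returns
        # 2 * 64 + 2 = 130 channels, this is 130 (flow embedding) + 196 (correlation:
        # 4 pyramid levels * 7 * 7 neighborhood) + 128 (track features) + 2 (mask + visibility).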
        self.updateformer = EfficientUpdateFormer(
            space_depth=6,
            time_depth=6,
            input_dim=self.input_dim,
            hidden_size=384,
            output_dim=self.latent_dim + 2,
            mlp_ratio=4.0,
            add_space_attn=add_space_attn,
            num_virtual_tracks=num_virtual_tracks,
        )

        time_grid = torch.linspace(0, window_len - 1, window_len).reshape(1, window_len, 1)

        self.register_buffer(
            "time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
        )

        self.register_buffer(
            "pos_emb",
            get_2d_sincos_pos_embed(
                embed_dim=self.input_dim,
                grid_size=(
                    model_resolution[0] // stride,
                    model_resolution[1] // stride,
                ),
            ),
        )
        self.norm = nn.GroupNorm(1, self.latent_dim)
        self.track_feat_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.vis_predictor = nn.Sequential(
            nn.Linear(self.latent_dim, 1),
        )

    def forward_window(
        self,
        fmaps,
        coords,
        track_feat=None,
        vis=None,
        track_mask=None,
        attention_mask=None,
        iters=4,
    ):
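        # Shape notation, inferred from the reshapes below:
        #   B = batch, S = frames in the window, N = tracks,
        #   C = self.latent_dim (feature channels), E = self.input_dim (token size).
        # fmaps: B S C H/stride W/stride, coords: B S N 2, track_feat: B S N C,
        # vis: B S N 1, track_mask: B S N 1, attention_mask: B S N.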
        B, S_init, N, __ = track_mask.shape
        B, S, *_ = fmaps.shape

        track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
        track_mask_vis = (
            torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
        )

        corr_block = CorrBlock(
            fmaps,
            num_levels=4,
            radius=3,
            padding_mode="border",
        )

        sampled_pos_emb = (
            sample_features4d(self.pos_emb.repeat(B, 1, 1, 1), coords[:, 0])
            .reshape(B * N, self.input_dim)
            .unsqueeze(1)
        )

        coord_preds = []
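        # Iteratively refine coordinates and track features with the update transformer.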
        for __ in range(iters):
            coords = coords.detach()
            corr_block.corr(track_feat)
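
            # Sample multi-scale correlation patches around the current coordinate estimates.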
            fcorrs = corr_block.sample(coords)

            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
            flow_emb = get_2d_embedding(flows, 64, cat_coords=True)

            track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
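
            # Build transformer tokens: flow embedding, correlation features, track features,
            # and the mask/visibility flags, plus spatial and temporal positional embeddings.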
            transformer_input = torch.cat([flow_emb, fcorrs, track_feat_, track_mask_vis], dim=2)
            x = transformer_input + sampled_pos_emb + self.time_emb
            x = x.view(B, N, S, -1)

            delta = self.updateformer(
                x,
                attention_mask.reshape(B * S, N),
            )

            delta_coords = delta[..., :2].permute(0, 2, 1, 3)
            coords = coords + delta_coords
            coord_preds.append(coords * self.stride)
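
            # Apply the predicted residual to the track appearance features.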
            delta_feats_ = delta[..., 2:].reshape(B * N * S, self.latent_dim)
            track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
            track_feat_ = self.track_feat_updater(self.norm(delta_feats_)) + track_feat_
            track_feat = track_feat_.reshape(B, N, S, self.latent_dim).permute(
                0, 2, 1, 3
            )

        vis_pred = self.vis_predictor(track_feat).reshape(B, S, N)
        return coord_preds, vis_pred

    def get_track_feat(self, fmaps, queried_frames, queried_coords):
        sample_frames = queried_frames[:, None, :, None]
        sample_coords = torch.cat(
            [
                sample_frames,
                queried_coords[:, None],
            ],
            dim=-1,
        )
        sample_track_feats = sample_features5d(fmaps, sample_coords)
        return sample_track_feats

    def init_video_online_processing(self):
        self.online_ind = 0
        self.online_track_feat = None
        self.online_coords_predicted = None
        self.online_vis_predicted = None

    def forward(self, video, queries, iters=4, is_train=False, is_online=False):
        """Predict tracks

        Args:
            video (FloatTensor[B, T, 3, H, W]): input videos.
            queries (FloatTensor[B, N, 3]): point queries of the form (t, x, y).
            iters (int, optional): number of updates. Defaults to 4.
            is_train (bool, optional): enables training mode. Defaults to False.
            is_online (bool, optional): enables online mode. Defaults to False. Before enabling, call model.init_video_online_processing().

        Returns:
            - coords_predicted (FloatTensor[B, T, N, 2]): predicted track coordinates in pixels.
            - vis_predicted (FloatTensor[B, T, N]): predicted visibility, in [0, 1].
            - train_data: `None` if `is_train` is false, otherwise:
                - all_vis_predictions (List[FloatTensor[B, S, N, 1]]): per-window visibility predictions.
                - all_coords_predictions (List[FloatTensor[B, S, N, 2]]): per-window coordinate predictions (one entry per refinement iteration).
                - mask (BoolTensor[B, T, N]): marks the frames at or after each query's start frame.
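
        Example (illustrative sketch only; shapes follow from the code above, not from a test run):
            >>> model = CoTracker2()
            >>> video = torch.randn(1, 24, 3, 384, 512)           # B T C H W, default model resolution
            >>> queries = torch.tensor([[[0.0, 100.0, 150.0]]])   # one query: frame 0, x=100, y=150
            >>> tracks, visibility, _ = model(video, queries)
            >>> tracks.shape, visibility.shape
            (torch.Size([1, 24, 1, 2]), torch.Size([1, 24, 1]))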
        """
        B, T, C, H, W = video.shape
        B, N, __ = queries.shape
        S = self.window_len
        device = queries.device

        assert S >= 2
        if is_online:
            assert T <= S, "Online mode: video chunk must be <= window size."
            assert self.online_ind is not None, "Call model.init_video_online_processing() first."
            assert not is_train, "Training not supported in online mode."
        step = S // 2
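        # Consecutive windows overlap by half the window length.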
        video = 2 * (video / 255.0) - 1.0

        queried_frames = queries[:, :, 0].long()

        queried_coords = queries[..., 1:]
        queried_coords = queried_coords / self.stride

        coords_predicted = torch.zeros((B, T, N, 2), device=device)
        vis_predicted = torch.zeros((B, T, N), device=device)
        if is_online:
            if self.online_coords_predicted is None:
                self.online_coords_predicted = coords_predicted
                self.online_vis_predicted = vis_predicted
            else:
                pad = min(step, T - step)
                coords_predicted = F.pad(
                    self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
                )
                vis_predicted = F.pad(self.online_vis_predicted, (0, 0, 0, pad), "constant")
        all_coords_predictions, all_vis_predictions = [], []

        pad = S - T if is_online else (S - T % S) % S
        video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
            B, -1, C, H, W
        )

        fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
            B, -1, self.latent_dim, H // self.stride, W // self.stride
        )

        track_feat = self.get_track_feat(
            fmaps,
            queried_frames - self.online_ind if is_online else queried_frames,
            queried_coords,
        ).repeat(1, S, 1, 1)
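
        # In online mode, accumulate per-query features only once the query's frame enters the current window.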
        if is_online:
            sample_frames = queried_frames[:, None, :, None]
            left = 0 if self.online_ind == 0 else self.online_ind + step
            right = self.online_ind + S
            sample_mask = (sample_frames >= left) & (sample_frames < right)
            if self.online_track_feat is None:
                self.online_track_feat = torch.zeros_like(track_feat, device=device)
            self.online_track_feat += track_feat * sample_mask
            track_feat = self.online_track_feat.clone()

        num_windows = (T - S + step - 1) // step + 1
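        # i.e. ceil((T - S) / step) + 1 sliding windows are needed to cover the (padded) video.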

        indices = [self.online_ind] if is_online else range(0, step * num_windows, step)

        coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
        vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
        for ind in indices:
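            # Warm start: for frames that overlap the previous window, initialise coordinates
            # and visibility from the previous window's predictions instead of the raw queries.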
            if ind > 0:
                overlap = S - step
                copy_over = (queried_frames < ind + overlap)[:, None, :, None]
                coords_prev = torch.nn.functional.pad(
                    coords_predicted[:, ind : ind + overlap] / self.stride,
                    (0, 0, 0, 0, 0, step),
                    "replicate",
                )
                vis_prev = torch.nn.functional.pad(
                    vis_predicted[:, ind : ind + overlap, :, None].clone(),
                    (0, 0, 0, 0, 0, step),
                    "replicate",
                )
                coords_init = torch.where(
                    copy_over.expand_as(coords_init), coords_prev, coords_init
                )
                vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)

            # Queries that first appear after the end of this window are masked out of attention.
            attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1)

            # track_mask is True for frames at or after a query's start frame within this window.
            track_mask = (
                queried_frames[:, None, :, None]
                <= torch.arange(ind, ind + S, device=device)[None, :, None, None]
            ).contiguous()

            if ind > 0:
                track_mask[:, :overlap, :, :] = False

            coords, vis = self.forward_window(
                fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
                coords=coords_init,
                track_feat=attention_mask.unsqueeze(-1) * track_feat,
                vis=vis_init,
                track_mask=track_mask,
                attention_mask=attention_mask,
                iters=iters,
            )

            S_trimmed = T if is_online else min(T - ind, S)
            coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
            vis_predicted[:, ind : ind + S] = vis[:, :S_trimmed]
            if is_train:
                all_coords_predictions.append([coord[:, :S_trimmed] for coord in coords])
                all_vis_predictions.append(torch.sigmoid(vis[:, :S_trimmed]))

        if is_online:
            self.online_ind += step
            self.online_coords_predicted = coords_predicted
            self.online_vis_predicted = vis_predicted
        vis_predicted = torch.sigmoid(vis_predicted)

        if is_train:
            mask = queried_frames[:, None] <= torch.arange(0, T, device=device)[None, :, None]
            train_data = (all_coords_predictions, all_vis_predictions, mask)
        else:
            train_data = None

        return coords_predicted, vis_predicted, train_data


class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        add_space_attn=True,
        num_virtual_tracks=64,
    ):
        super().__init__()
        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.add_space_attn = add_space_attn
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks
        self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_class=Attention,
                )
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=Attention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
                    for _ in range(space_depth)
                ]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
                    for _ in range(space_depth)
                ]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def forward(self, input_tensor, mask=None):
        tokens = self.input_transform(input_tensor)
        B, _, T, _ = tokens.shape
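        # Append learnable virtual-track tokens. Spatial information flows
        # point tokens -> virtual tokens -> point tokens, so point tokens never
        # attend to each other directly, which keeps spatial attention cheap.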
        virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
        tokens = torch.cat([tokens, virtual_tokens], dim=1)
        _, N, _, _ = tokens.shape

        j = 0
        for i in range(len(self.time_blocks)):
            # Temporal self-attention: each track attends over the frames of the window.
            time_tokens = tokens.contiguous().view(B * N, T, -1)
            time_tokens = self.time_blocks[i](time_tokens)

            tokens = time_tokens.view(B, N, T, -1)
            # Interleave spatial attention (through the virtual tracks) every few temporal blocks.
            if self.add_space_attn and (
                i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
            ):
                space_tokens = (
                    tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
                )
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                virtual_tokens = self.space_virtual2point_blocks[j](
                    virtual_tokens, point_tokens, mask=mask
                )
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](
                    point_tokens, virtual_tokens, mask=mask
                )
                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3)
                j += 1
        tokens = tokens[:, : N - self.num_virtual_tracks]
        flow = self.flow_head(tokens)
        return flow


class CrossAttnBlock(nn.Module):
    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm_context = nn.LayerNorm(hidden_size)
        self.cross_attn = Attention(
            hidden_size, context_dim=context_dim, num_heads=num_heads, qkv_bias=True, **block_kwargs
        )

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, context, mask=None):
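        # Expand the per-token validity mask to (B, heads, query_len, key_len) and convert it
        # into an additive attention bias: invalid positions receive a large negative value.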
        attn_bias = None
        if mask is not None:
            if mask.shape[1] == x.shape[1]:
                mask = mask[:, None, :, None].expand(
                    -1, self.cross_attn.heads, -1, context.shape[1]
                )
            else:
                mask = mask[:, None, None].expand(-1, self.cross_attn.heads, x.shape[1], -1)

            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value
        x = x + self.cross_attn(
            self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
        )
        x = x + self.mlp(self.norm2(x))
        return x