from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ripe import utils
from ripe.utils.utils import gridify

log = utils.get_pylogger(__name__)


class KeypointSampler(nn.Module):
    """
    Sample keypoints according to a heatmap.

    Adapted from: https://github.com/verlab/DALF_CVPR_2023/blob/main/modules/models/DALF.py
    """

    def __init__(self, window_size=8):
        super().__init__()
        self.window_size = window_size
        self.idx_cells = None  # cache for the meshgrid indices built in precompute_idx_cells

    def sample(self, grid):
        """
        Sample keypoints given a grid where each cell has its logits stacked in the last dimension.

        Input
            grid: [B, C, H//w, W//w, w*w]

        Returns
            log_probs: [B, H//w, W//w] log-probabilities of the selected samples
            choices: [B, C, H//w, W//w] indices of the chosen positions
            accept_mask: [B, H//w, W//w] mask of accepted keypoints
            logits_selected: [B, H//w, W//w] logits of the chosen positions
        """
        # Stage 1: pick one candidate position per cell.
        chooser = torch.distributions.Categorical(logits=grid)
        choices = chooser.sample()
        logits_selected = torch.gather(grid, -1, choices.unsqueeze(-1)).squeeze(-1)

        # Stage 2: accept or reject each candidate with a Bernoulli draw on its logit.
        flipper = torch.distributions.Bernoulli(logits=logits_selected)
        accepted_choices = flipper.sample()

        # Summing log-probabilities is equivalent to multiplying the probabilities.
        log_probs = chooser.log_prob(choices) + flipper.log_prob(accepted_choices)

        accept_mask = accepted_choices.gt(0)

        return (
            log_probs.squeeze(1),
            choices,
            accept_mask.squeeze(1),
            logits_selected.squeeze(1),
        )

    def precompute_idx_cells(self, H, W, device):
        """Build a [1, 2, H//w, W//w, w*w] grid of per-cell (row, col) pixel indices."""
        idx_cells = gridify(
            torch.dstack(
                torch.meshgrid(
                    torch.arange(H, dtype=torch.float32, device=device),
                    torch.arange(W, dtype=torch.float32, device=device),
                )
            )
            .permute(2, 0, 1)
            .unsqueeze(0)
            .expand(1, -1, -1, -1),
            window_size=self.window_size,
        )

        return idx_cells

    def forward(self, x, mask_padding=None):
        """
        Sample keypoints from a heatmap.

        Input
            x: [B, C, H, W] heatmap
            mask_padding: [B, 1, H, W] mask for padding (optional)

        Returns
            keypoints: [B, H//w, W//w, 2] keypoints in (x, y) format
            log_probs: [B, H//w, W//w] log-probabilities of the selected keypoints
            mask: [B, H//w, W//w] mask of accepted keypoints
            mask_padding: [B, 1, H//w, W//w] mask of padding (optional)
            logits_selected: [B, H//w, W//w] logits of the selected keypoints
        """
        B, C, H, W = x.shape

        keypoint_cells = gridify(x, self.window_size)

        # A cell counts as valid only if none of its pixels is padding.
        mask_padding = (
            (torch.min(gridify(mask_padding, self.window_size), dim=4).values) if mask_padding is not None else None
        )

        # Rebuild the cached index grid only when the spatial resolution changes.
        if self.idx_cells is None or self.idx_cells.shape[2:4] != (
            H // self.window_size,
            W // self.window_size,
        ):
            self.idx_cells = self.precompute_idx_cells(H, W, x.device)

        log_probs, idx, mask, logits_selected = self.sample(keypoint_cells)

        # Translate the per-cell choices back to absolute (row, col) pixel coordinates.
        keypoints = (
            torch.gather(
                self.idx_cells.expand(B, -1, -1, -1, -1),
                -1,
                idx.repeat(1, 2, 1, 1).unsqueeze(-1),
            )
            .squeeze(-1)
            .permute(0, 2, 3, 1)
        )

        # Flip keypoints to (x, y) format.
        return keypoints.flip(-1), log_probs, mask, mask_padding, logits_selected
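

# Minimal usage sketch (not part of the original module): drive the sampler on a
# dummy heatmap. It assumes only what the class itself requires, i.e. that H and
# W are divisible by `window_size`; shapes follow the docstrings above.
def _demo_keypoint_sampler():
    sampler = KeypointSampler(window_size=8)
    heatmap = torch.randn(2, 1, 64, 64)  # [B, C, H, W] dummy detection logits
    kpts, log_probs, mask, _, logits_selected = sampler(heatmap)
    # kpts: [2, 8, 8, 2] in (x, y), one candidate per 8x8 cell;
    # mask: [2, 8, 8] marks the candidates accepted by the Bernoulli draw.
    return kpts[mask], log_probs, logits_selected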


class RIPE(nn.Module):
    """
    Base class for extracting keypoints and descriptors.

    Input
        x: [B, C, H, W] images

    Returns
        kpts: list of size [B] with the detected keypoints
        descs: list of size [B] with the descriptors
    """

    def __init__(
        self,
        net,
        upsampler,
        window_size: int = 8,
        non_linearity_dect=None,
        desc_shares: Optional[List[int]] = None,
        descriptor_dim: int = 256,
        device=None,
    ):
        super().__init__()
        self.net = net
        self.detector = KeypointSampler(window_size)
        self.upsampler = upsampler
        self.sampler = None
        self.window_size = window_size
        self.non_linearity_dect = non_linearity_dect if non_linearity_dect is not None else nn.Identity()

        log.info(f"Training with window size {window_size}.")
        log.info(f"Use {non_linearity_dect} as final non-linearity before the detection heatmap.")

        dim_coarse_desc = self.get_dim_raw_desc()

        if desc_shares is not None:
            assert upsampler.name == "HyperColumnFeatures", (
                "Individual descriptor convolutions are only supported with HyperColumnFeatures"
            )
            assert len(desc_shares) == 4, "desc_shares should have 4 elements"
            assert sum(desc_shares) == descriptor_dim, f"sum of desc_shares should be {descriptor_dim}"

            self.conv_dim_reduction_coarse_desc = nn.ModuleList()

            for dim_in, dim_out in zip(dim_coarse_desc, desc_shares):
                log.info(f"Training dim reduction descriptor with {dim_in} -> {dim_out} 1x1 conv")
                self.conv_dim_reduction_coarse_desc.append(
                    nn.Conv1d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)
                )
        else:
            if descriptor_dim is not None:
                log.info(f"Training dim reduction descriptor with {sum(dim_coarse_desc)} -> {descriptor_dim} 1x1 conv")
                self.conv_dim_reduction_coarse_desc = nn.Conv1d(
                    sum(dim_coarse_desc),
                    descriptor_dim,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            else:
                log.warning(
                    f"No descriptor dimension specified, no 1x1 conv will be applied! "
                    f"Direct usage of the {sum(dim_coarse_desc)}-dimensional raw descriptor"
                )
                self.conv_dim_reduction_coarse_desc = nn.Identity()

    def get_dim_raw_desc(self):
        layers_dims_encoder = self.net.get_dim_layers_encoder()

        if self.upsampler.name == "InterpolateSparse2d":
            return [layers_dims_encoder[-1]]
        elif self.upsampler.name == "HyperColumnFeatures":
            return layers_dims_encoder
        else:
            raise ValueError(f"Unknown interpolator {self.upsampler.name}")

    def detectAndCompute(self, img, threshold=0.5, top_k=2048, output_aux=False):
        self.train(False)

        if img.dim() == 3:
            img = img.unsqueeze(0)

        out = self(img, training=False)
        B, K, H, W = out["heatmap"].shape

        assert B == 1, "Batch size should be 1"

        kpts = [{"xy": self.NMS(out["heatmap"][b], threshold)} for b in range(B)]

        if top_k is not None:
            for b in range(B):
                scores = out["heatmap"][b].squeeze(0)[kpts[b]["xy"][:, 1].long(), kpts[b]["xy"][:, 0].long()]
                sorted_idx = torch.argsort(-scores)
                kpts[b]["xy"] = kpts[b]["xy"][sorted_idx[:top_k]]
                if "logprobs" in kpts[b]:
                    kpts[b]["logprobs"] = kpts[b]["logprobs"][sorted_idx[:top_k]]

        if kpts[0]["xy"].shape[0] == 0:
            raise RuntimeError("No keypoints detected")

        # The following works for batch size 1 only.
        descs = self.get_descs(out["coarse_descs"], img, kpts[0]["xy"].unsqueeze(0), H, W)
        descs = descs.squeeze(0)

        score_map = out["heatmap"][0].squeeze(0)

        kpts = kpts[0]["xy"]
        scores = score_map[kpts[:, 1], kpts[:, 0]]
        scores /= score_map.max()

        sort_idx = torch.argsort(-scores)
        kpts, descs, scores = kpts[sort_idx], descs[sort_idx], scores[sort_idx]

        if output_aux:
            return (
                kpts.float(),
                descs,
                scores,
                {
                    "heatmap": out["heatmap"],
                    "descs": out["coarse_descs"],
                    "conv": self.conv_dim_reduction_coarse_desc,
                },
            )

        return kpts.float(), descs, scores

    def NMS(self, x, threshold=3.0, kernel_size=3):
        """Keep only local maxima of x above threshold; returns their (x, y) coordinates."""
        pad = kernel_size // 2
        local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
        pos = (x == local_max) & (x > threshold)
        return pos.nonzero()[..., 1:].flip(-1)

    def get_descs(self, feature_map, guidance, kpts, H, W):
        descs = self.upsampler(feature_map, kpts, H, W)

        if isinstance(self.conv_dim_reduction_coarse_desc, nn.ModuleList):
            # Individual descriptor convolution for each encoder layer.
            desc_conv = []
            for desc, conv in zip(descs, self.conv_dim_reduction_coarse_desc):
                desc_conv.append(conv(desc.permute(0, 2, 1)).permute(0, 2, 1))
            desc = torch.cat(desc_conv, dim=-1)
        else:
            # One shared convolution over the concatenated layers.
            desc = torch.cat(descs, dim=-1)
            desc = self.conv_dim_reduction_coarse_desc(desc.permute(0, 2, 1)).permute(0, 2, 1)

        desc = F.normalize(desc, dim=2)

        return desc

    def forward(self, x, mask_padding=None, training=False):
        B, C, H, W = x.shape

        out = self.net(x)
        out["heatmap"] = self.non_linearity_dect(out["heatmap"])

        if training:
            kpts, log_probs, mask, mask_padding, logits_selected = self.detector(out["heatmap"], mask_padding)

            # Discard keypoints within a 16-pixel border of the image.
            filter_A = kpts[:, :, :, 0] >= 16
            filter_B = kpts[:, :, :, 1] >= 16
            filter_C = kpts[:, :, :, 0] < W - 16
            filter_D = kpts[:, :, :, 1] < H - 16
            filter_all = filter_A * filter_B * filter_C * filter_D

            mask = mask * filter_all

            return (
                kpts.view(B, -1, 2),
                log_probs.view(B, -1),
                mask.view(B, -1),
                mask_padding.view(B, -1),
                logits_selected.view(B, -1),
                out,
            )
        else:
            return out
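

# Hedged usage sketch (not from the original file): extracting keypoints and
# descriptors with a trained model. `model` is a RIPE instance built with the
# project's own `net` and `upsampler`; nothing here assumes their internals.
def _demo_detect(model: "RIPE", image: torch.Tensor):
    # image: [C, H, W] or [1, C, H, W]; detectAndCompute handles both.
    kpts, descs, scores = model.detectAndCompute(image, threshold=0.5, top_k=2048)
    # kpts: [N, 2] in (x, y), descs: [N, D] L2-normalized, scores: [N] sorted descending.
    return kpts, descs, scores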


def output_number_trainable_params(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    nb_params = sum([np.prod(p.size()) for p in model_parameters])
    print(f"Number of trainable parameters: {nb_params:d}")