import torch
import numpy as np
import pdb

debug_cnt = -1

def make_batch(augmentor, difficulty = 0.3, train = True):
    '''
        Sample a random batch of images from the augmentor's train/test split and
        create two augmented views: p1 (geometric warp) and p2 (same, with TPS
        deformation enabled), together with the transforms H1, H2 describing them.
    '''
    Hs = []
    img_list = augmentor.train if train else augmentor.test
    dev = augmentor.device

    batch_images = []

    with torch.no_grad():  # we don't require grads in the augmentation
        for b in range(augmentor.batch_size):
            rdidx = np.random.randint(len(img_list))
            img = torch.tensor(img_list[rdidx], dtype=torch.float32).permute(2,0,1).to(augmentor.device).unsqueeze(0)
            batch_images.append(img)

        batch_images = torch.cat(batch_images)

        p1, H1 = augmentor(batch_images, difficulty)
        p2, H2 = augmentor(batch_images, difficulty, TPS = True, prob_deformation = 0.7)
        # p2, H2 = augmentor(batch_images, difficulty, TPS = False, prob_deformation = 0.7)

    return p1, p2, H1, H2

def plot_corrs(p1, p2, src_pts, tgt_pts):
    '''
        Debug helper: plot a random subset of ground-truth correspondences
        between a source image (p1) and a target image (p2).
    '''
    import matplotlib.pyplot as plt
    p1 = p1.cpu()
    p2 = p2.cpu()
    src_pts = src_pts.cpu() ; tgt_pts = tgt_pts.cpu()
    rnd_idx = np.random.randint(len(src_pts), size=200)
    src_pts = src_pts[rnd_idx, ...]
    tgt_pts = tgt_pts[rnd_idx, ...]

    # Plot ground-truth correspondences
    fig, ax = plt.subplots(1, 2, figsize=(18, 12))
    colors = np.random.uniform(size=(len(tgt_pts), 3))

    # Src image
    img = p1
    for i, p in enumerate(src_pts):
        ax[0].scatter(p[0], p[1], color=colors[i])
    ax[0].imshow(img.permute(1,2,0).numpy()[...,::-1])

    # Target img
    img2 = p2
    for i, p in enumerate(tgt_pts):
        ax[1].scatter(p[0], p[1], color=colors[i])
    ax[1].imshow(img2.permute(1,2,0).numpy()[...,::-1])

    plt.show()

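# Usage sketch (hypothetical, synthetic data): in training, the images come from
# make_batch() and the points from get_corresponding_pts(); random tensors are
# used here only to illustrate the expected shapes and (x, y) pixel convention.
#   img1, img2 = torch.rand(3, 120, 160), torch.rand(3, 120, 160)   # [C, H, W]
#   pts = torch.rand(300, 2) * torch.tensor([160.0, 120.0])         # (x, y) coords
#   plot_corrs(img1, img2, pts, pts)   # identity "warp", just to inspect the plot
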
def get_corresponding_pts(p1, p2, H, H2, augmentor, h, w, crop = None):
    '''
        Get dense corresponding points between the two augmented views.
        Returns per-image lists of negatives (target points with no valid match
        in the source view) and positives (src <-> tgt matches, stored as
        (src_x, src_y, tgt_x, tgt_y) rows in (h, w) grid coordinates).
    '''
    global debug_cnt
    negatives, positives = [], []

    with torch.no_grad():
        # real input resolution of the samples
        rh, rw = p1.shape[-2:]
        ratio = torch.tensor([rw/w, rh/h], device = p1.device)

        (H, mask1) = H
        (H2, src, W, A, mask2) = H2

        # Generate meshgrid of target pts
        x, y = torch.meshgrid(torch.arange(w, device=p1.device), torch.arange(h, device=p1.device), indexing='xy')
        mesh = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], dim=-1)
        target_pts = mesh.view(-1, 2) * ratio

        for batch_idx in range(len(p1)):
            with torch.no_grad():
                # Pack all transformations into T
                T = (H[batch_idx], H2[batch_idx],
                     src[batch_idx].unsqueeze(0), W[batch_idx].unsqueeze(0), A[batch_idx].unsqueeze(0))

                # We now warp the target points to the src image
                src_pts = augmentor.get_correspondences(target_pts, T)  # target to src
                tgt_pts = target_pts

                # Check out-of-bounds points
                mask_valid = (src_pts[:, 0] >= 0) & (src_pts[:, 1] >= 0) & \
                             (src_pts[:, 0] < rw) & (src_pts[:, 1] < rh)
                negatives.append(tgt_pts[~mask_valid])
                tgt_pts = tgt_pts[mask_valid]
                src_pts = src_pts[mask_valid]

                # Remove invalid pixels
                mask_valid = mask1[batch_idx, src_pts[:,1].long(), src_pts[:,0].long()] & \
                             mask2[batch_idx, tgt_pts[:,1].long(), tgt_pts[:,0].long()]
                tgt_pts = tgt_pts[mask_valid]
                src_pts = src_pts[mask_valid]

                # Limit number of matches if desired
                if crop is not None:
                    rnd_idx = torch.randperm(len(src_pts), device=src_pts.device)[:crop]
                    src_pts = src_pts[rnd_idx]
                    tgt_pts = tgt_pts[rnd_idx]

                if debug_cnt >= 0 and debug_cnt < 4:
                    plot_corrs(p1[batch_idx], p2[batch_idx], src_pts, tgt_pts)
                    debug_cnt += 1

                # Back to (h, w) grid coordinates
                src_pts = src_pts / ratio
                tgt_pts = tgt_pts / ratio

                # Check out-of-bounds points (with a safety border)
                padto = 10 if crop is not None else 2
                mask_valid1 = (src_pts[:, 0] >= (0 + padto)) & (src_pts[:, 1] >= (0 + padto)) & \
                              (src_pts[:, 0] < (w - padto)) & (src_pts[:, 1] < (h - padto))
                mask_valid2 = (tgt_pts[:, 0] >= (0 + padto)) & (tgt_pts[:, 1] >= (0 + padto)) & \
                              (tgt_pts[:, 0] < (w - padto)) & (tgt_pts[:, 1] < (h - padto))
                mask_valid = mask_valid1 & mask_valid2
                tgt_pts = tgt_pts[mask_valid]
                src_pts = src_pts[mask_valid]

                # Remove repeated correspondences via a (h, w, 4) lookup table
                lut_mat = torch.ones((h, w, 4), device = src_pts.device, dtype = src_pts.dtype) * -1

                try:
                    lut_mat[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
                    mask_valid = torch.all(lut_mat >= 0, dim=-1)
                    points = lut_mat[mask_valid]
                    positives.append(points)
                except:
                    pdb.set_trace()
                    print('..')

    return negatives, positives

def crop_patches(tensor, coords, size = 7):
    '''
        Crop [size x size] patches around 2D (x, y) coordinates from a tensor.
    '''
    B, C, H, W = tensor.shape
    x, y = coords[:, 0], coords[:, 1]
    y = y.view(-1, 1, 1)
    x = x.view(-1, 1, 1)
    halfsize = size // 2

    # Create meshgrid for indexing
    x_offset, y_offset = torch.meshgrid(torch.arange(-halfsize, halfsize+1), torch.arange(-halfsize, halfsize+1), indexing='xy')
    y_offset = y_offset.to(tensor.device)
    x_offset = x_offset.to(tensor.device)

    # Compute indices around each coordinate (shifted by halfsize to account for padding)
    y_indices = (y + y_offset.view(1, size, size)).squeeze(0) + halfsize
    x_indices = (x + x_offset.view(1, size, size)).squeeze(0) + halfsize

    # Handle out-of-boundary indices with zero padding
    tensor_padded = torch.nn.functional.pad(tensor, (halfsize, halfsize, halfsize, halfsize), mode='constant')

    # Index tensor to get patches
    patches = tensor_padded[:, :, y_indices, x_indices]  # [B, C, N, size, size]

    return patches

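# Usage sketch (hypothetical, synthetic data): crop 7x7 patches around integer
# keypoint locations in a feature map. Coordinates are (x, y) and must be
# integer (long) tensors so they can be used for indexing.
#   feats = torch.rand(2, 64, 60, 80)                                # [B, C, H, W]
#   kpts = torch.stack([torch.randint(0, 80, (100,)),                # x in [0, W)
#                       torch.randint(0, 60, (100,))], dim=-1)       # y in [0, H)
#   patches = crop_patches(feats, kpts, size=7)                      # [B, C, 100, 7, 7]
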
def subpix_softmax2d(heatmaps, temp = 0.25):
    '''
        Soft-argmax over [N, H, W] heatmaps: returns the softmax-weighted
        expected (x, y) offset of each heatmap relative to its center.
    '''
    N, H, W = heatmaps.shape
    heatmaps = torch.softmax(temp * heatmaps.view(-1, H*W), -1).view(-1, H, W)
    x, y = torch.meshgrid(torch.arange(W, device = heatmaps.device), torch.arange(H, device = heatmaps.device), indexing = 'xy')
    x = x - (W//2)
    y = y - (H//2)

    coords_x = (x[None, ...] * heatmaps)
    coords_y = (y[None, ...] * heatmaps)
    coords = torch.cat([coords_x[..., None], coords_y[..., None]], -1).view(N, H*W, 2)
    coords = coords.sum(1)

    return coords
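

if __name__ == "__main__":
    # Quick sanity check for subpix_softmax2d (synthetic heatmap, not part of
    # training): a strong peak right of and above the center should yield an
    # offset of roughly (+2, -1) pixels after the softmax-weighted expectation.
    hm = torch.zeros(1, 7, 7)
    hm[0, 2, 5] = 40.0            # peak at (x=5, y=2); patch center is (3, 3)
    print(subpix_softmax2d(hm))   # ~ tensor([[ 2.0, -1.0]])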