Spaces:
Running
Running
import torch | |
import numpy as np | |
import pdb | |
debug_cnt = -1 | |
def make_batch(augmentor, difficulty = 0.3, train = True):
    '''
        Sample a random batch of images and produce two augmented views of it.

        Args:
            augmentor: callable augmentation object. This function reads its
                `train` / `test` image lists, `batch_size` and `device`, and
                calls it twice on the same image batch.
            difficulty: forwarded unchanged to both augmentor calls.
            train: draw images from `augmentor.train` when True, otherwise
                from `augmentor.test`.

        Returns:
            (p1, p2, H1, H2): the two outputs of each augmentor call. The
            second call is made with `TPS = True, prob_deformation = 0.7`.
    '''
    img_list = augmentor.train if train else augmentor.test
    batch_images = []

    with torch.no_grad(): # we dont require grads in the augmentation
        for _ in range(augmentor.batch_size):
            # Images are drawn uniformly, with replacement.
            rdidx = np.random.randint(len(img_list))
            # permute(2,0,1) moves the last axis first
            # (presumably H x W x C -> C x H x W; confirm against the dataset).
            img = torch.tensor(img_list[rdidx], dtype=torch.float32).permute(2, 0, 1)
            batch_images.append(img.to(augmentor.device).unsqueeze(0))

        batch_images = torch.cat(batch_images)

        p1, H1 = augmentor(batch_images, difficulty)
        p2, H2 = augmentor(batch_images, difficulty, TPS = True, prob_deformation = 0.7)

    return p1, p2, H1, H2
def plot_corrs(p1, p2, src_pts, tgt_pts):
    '''
        Debug helper: show two images side by side with a random subset of
        their correspondences drawn on top, matching points sharing a colour.

        Args:
            p1, p2: image tensors, permuted with (1,2,0) before display, so
                channel-first layout is expected.
            src_pts, tgt_pts: row-aligned point tensors; column 0 is plotted
                as x and column 1 as y, on `p1` and `p2` respectively.

        200 correspondences are sampled with replacement. When there are no
        correspondences only the images are shown.
    '''
    import matplotlib.pyplot as plt

    src_pts = src_pts.cpu()
    tgt_pts = tgt_pts.cpu()

    n_pts = len(src_pts)
    if n_pts > 0:
        # np.random.randint(0, ...) raises ValueError, so sample only when
        # there is something to sample from.
        rnd_idx = np.random.randint(n_pts, size=200)
        src_pts = src_pts[rnd_idx, ...]
        tgt_pts = tgt_pts[rnd_idx, ...]

    # One random RGB colour per sampled correspondence, shared by both views.
    colors = np.random.uniform(size=(len(tgt_pts), 3))

    #Plot ground-truth correspondences
    fig, ax = plt.subplots(1, 2, figsize=(18, 12))

    for axis, img, pts in zip(ax, (p1.cpu(), p2.cpu()), (src_pts, tgt_pts)):
        if len(pts) > 0:
            axis.scatter(pts[:, 0].numpy(), pts[:, 1].numpy(), c=colors)
        # [..., ::-1] reverses the channel axis
        # (presumably BGR -> RGB; confirm against the image loader).
        axis.imshow(img.permute(1, 2, 0).numpy()[..., ::-1])

    plt.show()
def get_corresponding_pts(p1, p2, H, H2, augmentor, h, w, crop = None):
    '''
        Get dense corresponding points between two augmented views.

        A dense (x, y) grid of `h * w` target points is scaled to the pixel
        resolution of `p1`, mapped to the source view through
        `augmentor.get_correspondences`, and filtered by image bounds, by the
        two validity masks and by a border margin.

        Args:
            p1: batch of source views. Its length, `shape[-2:]` and device
                drive the computation.
            p2: batch of target views (only used by the debug plot).
            H: tuple `(H, mask1)` for the source view. `mask1` is indexed as
                `[batch, y, x]` and must be boolean.
            H2: tuple `(H2, src, W, A, mask2)` for the target view. `mask2` is
                indexed like `mask1`. The remaining entries are passed per
                sample to the augmentor (presumably homography and TPS
                parameters; confirm against the augmentor).
            augmentor: object exposing `get_correspondences(points, T)`.
            h, w: resolution of the dense grid the points are expressed in.
            crop: if not None, keep at most `crop` random correspondences per
                sample and widen the border margin from 2 to 10 grid cells.

        Returns:
            (negatives, positives), two lists with one tensor per sample.
            `negatives[i]` holds the target points (in `p1` pixel units) whose
            warp lands outside the source image. `positives[i]` is an (M, 4)
            tensor of `[src_x, src_y, tgt_x, tgt_y]` in grid units, with at
            most one correspondence per integer source cell.

        Errors raised while building the correspondences propagate to the
        caller; they are no longer intercepted with an interactive debugger.
    '''
    global debug_cnt
    negatives, positives = [], []

    # Pure coordinate bookkeeping: no gradients are needed anywhere below.
    with torch.no_grad():
        device = p1.device

        #real input res of samples
        rh, rw = p1.shape[-2:]
        ratio = torch.tensor([rw / w, rh / h], device = device)

        (H, mask1) = H
        (H2, src, W, A, mask2) = H2

        #Generate meshgrid of target pts
        x, y = torch.meshgrid(torch.arange(w, device = device),
                              torch.arange(h, device = device), indexing = 'xy')
        target_pts = torch.stack([x, y], dim = -1).view(-1, 2) * ratio

        # Border margin (in grid cells) applied to the final correspondences.
        padto = 10 if crop is not None else 2

        for batch_idx in range(len(p1)):
            #Pack all transformations into T
            T = (H[batch_idx], H2[batch_idx],
                 src[batch_idx].unsqueeze(0), W[batch_idx].unsqueeze(0), A[batch_idx].unsqueeze(0))

            #We now warp the target points to src image
            src_pts = augmentor.get_correspondences(target_pts, T) #target to src
            tgt_pts = target_pts

            #Check out of bounds points
            in_bounds = (src_pts[:, 0] >= 0) & (src_pts[:, 1] >= 0) & \
                        (src_pts[:, 0] < rw) & (src_pts[:, 1] < rh)

            negatives.append(tgt_pts[~in_bounds])
            tgt_pts = tgt_pts[in_bounds]
            src_pts = src_pts[in_bounds]

            #Remove invalid pixels
            unmasked = mask1[batch_idx, src_pts[:, 1].long(), src_pts[:, 0].long()] & \
                       mask2[batch_idx, tgt_pts[:, 1].long(), tgt_pts[:, 0].long()]
            tgt_pts = tgt_pts[unmasked]
            src_pts = src_pts[unmasked]

            # limit nb of matches if desired
            if crop is not None:
                rnd_idx = torch.randperm(len(src_pts), device = src_pts.device)[:crop]
                src_pts = src_pts[rnd_idx]
                tgt_pts = tgt_pts[rnd_idx]

            # Optional visual check, enabled by setting the module-level
            # `debug_cnt` to a value in [0, 4).
            if debug_cnt >= 0 and debug_cnt < 4:
                plot_corrs(p1[batch_idx], p2[batch_idx], src_pts, tgt_pts)
                debug_cnt += 1

            # Back from input-pixel units to grid units.
            src_pts = src_pts / ratio
            tgt_pts = tgt_pts / ratio

            #Check out of bounds points
            src_inside = (src_pts[:, 0] >= padto) & (src_pts[:, 1] >= padto) & \
                         (src_pts[:, 0] < (w - padto)) & (src_pts[:, 1] < (h - padto))
            tgt_inside = (tgt_pts[:, 0] >= padto) & (tgt_pts[:, 1] >= padto) & \
                         (tgt_pts[:, 0] < (w - padto)) & (tgt_pts[:, 1] < (h - padto))
            inside = src_inside & tgt_inside
            tgt_pts = tgt_pts[inside]
            src_pts = src_pts[inside]

            #Remove repeated correspondences
            # Each integer source cell stores a single [src, tgt] row; cells
            # left at -1 were never written and are dropped afterwards.
            lut_mat = torch.ones((h, w, 4), device = src_pts.device, dtype = src_pts.dtype) * -1
            lut_mat[src_pts[:, 1].long(), src_pts[:, 0].long()] = torch.cat([src_pts, tgt_pts], dim = 1)
            filled = torch.all(lut_mat >= 0, dim = -1)
            positives.append(lut_mat[filled])

    return negatives, positives
def crop_patches(tensor, coords, size = 7):
    '''
        Crop [size x size] patches around 2D coordinates from a tensor.

        Args:
            tensor: 4-D tensor indexed as [B, C, y, x].
            coords: integer tensor of shape [N, 2] holding (x, y) positions.
            size: patch side length. For odd sizes the patch is centred on
                the coordinate. For even sizes it spans
                [c - size//2, c + size//2 - 1] along each axis.

        Returns:
            Patches of shape [B, C, N, size, size]. Pixels falling outside
            the tensor (up to size//2 beyond the border) read as zero.

        NOTE(review): as in the original implementation, when N == 1 the
        leading index dimension is squeezed away and the result is
        [B, C, size, size]. Confirm callers rely on this before changing it.
    '''
    x = coords[:, 0].view(-1, 1, 1)
    y = coords[:, 1].view(-1, 1, 1)
    halfsize = size // 2

    # Exactly `size` offsets per axis, so even sizes work as well as odd ones.
    offsets = torch.arange(size, device=tensor.device) - halfsize

    # Compute indices around each coordinate; the `+ halfsize` shifts into
    # the padded tensor. Rows and columns are kept as broadcastable vectors.
    y_indices = (y + offsets.view(1, size, 1)).squeeze(0) + halfsize
    x_indices = (x + offsets.view(1, 1, size)).squeeze(0) + halfsize

    # Handle out-of-boundary indices with (zero) padding
    tensor_padded = torch.nn.functional.pad(tensor, (halfsize, halfsize, halfsize, halfsize), mode='constant')

    # Index tensor to get patches: the index tensors broadcast to
    # [N, size, size], giving [B, C, N, size, size].
    patches = tensor_padded[:, :, y_indices, x_indices]

    return patches
def subpix_softmax2d(heatmaps, temp = 0.25):
    '''
        Soft-argmax over a batch of 2D heatmaps.

        Each heatmap is multiplied by `temp`, turned into a probability map
        with a softmax over all of its cells, and used to average the cell
        offsets measured from the centre cell (W//2, H//2).

        Args:
            heatmaps: tensor of shape [N, H, W]; it does not need to be
                contiguous.
            temp: factor applied to the heatmaps before the softmax (larger
                values give a sharper distribution).

        Returns:
            Tensor of shape [N, 2] with the expected (x, y) offset from the
            centre cell of each heatmap.
    '''
    N, H, W = heatmaps.shape
    # reshape (rather than view) also accepts non-contiguous inputs such as slices
    probs = torch.softmax(temp * heatmaps.reshape(N, H * W), dim=-1)

    x, y = torch.meshgrid(torch.arange(W, device = heatmaps.device),
                          torch.arange(H, device = heatmaps.device), indexing = 'xy')
    # Per-cell (x, y) offsets relative to the centre cell, flattened in the
    # same row-major order as `probs`.
    offsets = torch.stack([x - (W // 2), y - (H // 2)], dim=-1).view(1, H * W, 2)

    # Expectation of the offsets under the softmax distribution.
    coords = (probs.unsqueeze(-1) * offsets).sum(1)
    return coords