Spaces:
Running
Running
"""
"LiftFeat: 3D Geometry-Aware Local Feature Matching"
COCO_20k image augmentor
"""
import torch | |
from torch import nn | |
from torch.utils.data import Dataset | |
import torch.utils.data as data | |
from torchvision import transforms | |
import torch.nn.functional as F | |
import cv2 | |
import kornia | |
import kornia.augmentation as K | |
from kornia.geometry.transform import get_tps_transform as findTPS | |
from kornia.geometry.transform import warp_points_tps, warp_image_tps | |
import glob | |
import random | |
import tqdm | |
import numpy as np | |
import pdb | |
import time | |
# Seed every RNG this module draws from so augmentations are reproducible.
# NumPy's global RNG must be seeded too: the homography parameters, the TPS
# on/off decision and the photometric coin flips all come from np.random.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
def generateRandomTPS(shape,grid=(8,6),GLOBAL_MULTIPLIER=0.3,prob=0.5):
    """Sample a random thin-plate-spline (TPS) warp on a regular control grid.

    Args:
        shape: (h, w) extent over which the control grid is laid out.
        grid: number of grid cells along h and w; the control grid therefore
            has (grid[0] + 1) x (grid[1] + 1) points.
        GLOBAL_MULTIPLIER: deformation strength. Each control point is moved
            by at most a quarter of a cell times min(0.97, 2 * GLOBAL_MULTIPLIER).
        prob: probability that the control points are displaced at all;
            otherwise the destination grid equals the source grid.

    Returns:
        (src, weights, A): ``src`` is the (1, N, 2) control grid normalized to
        [-1, 1]; ``weights`` and ``A`` are whatever kornia's
        ``get_tps_transform(dst, src)`` returns for that grid.
    """
    h, w = shape
    n_rows, n_cols = grid[0] + 1, grid[1] + 1
    sh, sw = h/grid[0], w/grid[1]

    # linspace with an explicit point count instead of arange with a float
    # step: arange(0, h + sh, sh) can emit one extra point through rounding,
    # which would no longer match the (n_rows, n_cols, 2) offsets below.
    src = torch.dstack(torch.meshgrid(torch.linspace(0, h, n_rows),
                                      torch.linspace(0, w, n_cols),
                                      indexing='ij'))

    # Uniform offsets in [-0.5, 0.5), scaled to a fraction of half a cell.
    offsets = torch.rand(n_rows, n_cols, 2) - 0.5
    offsets *= torch.tensor([ sh/2, sw/2 ]).view(1, 1, 2) * min(0.97, 2.0 * GLOBAL_MULTIPLIER)
    dst = src + offsets if np.random.uniform() < prob else src

    # Flatten to (1, N, 2) and normalize both grids to [-1, 1].
    extent = torch.tensor([h,w]).view(1,1,2)
    src = (src.view(1, -1, 2) / extent) * 2 - 1.
    dst = (dst.view(1, -1, 2) / extent) * 2 - 1.

    weights, A = findTPS(dst, src)
    return src, weights, A
def generateRandomHomography(shape,GLOBAL_MULTIPLIER=0.3):
    """Compose a random 3x3 homography acting about the image center.

    Args:
        shape: (h, w) of the image the homography will be applied to.
        GLOBAL_MULTIPLIER: scales the standard deviation of the translation
            jitter, the shear terms and the projective terms. Rotation and
            scale ranges do not depend on it.

    Returns:
        A (3, 3) numpy array equal to
        H_back @ H_scale @ H_projective @ H_affine @ H_rotation @ H_to_origin.
    """
    center_x, center_y = shape[1]/2.0, shape[0]/2.0

    # Random draws, kept in this exact order so a seeded RNG yields the
    # same matrix: rotation, scale, translation jitter, shear, projective.
    angle = np.radians(np.random.uniform(-30, 30))
    scale_x, scale_y = np.random.uniform(0.35, 1.2, 2)
    jitter_x, jitter_y = np.random.normal(0, 120.0*GLOBAL_MULTIPLIER, 2)
    shear_lower, shear_upper = np.random.normal(0, 0.6*GLOBAL_MULTIPLIER, 2)
    persp_x, persp_y = np.random.normal(0, 0.006*GLOBAL_MULTIPLIER, 2)

    cos_a, sin_a = np.cos(angle), np.sin(angle)

    # Elementary transforms, listed left-to-right in multiplication order.
    factors = [
        # move the origin back to the image center, plus random jitter
        np.array(((1.0, 0, center_x + jitter_x), (0, 1, center_y + jitter_y), (0, 0, 1))),
        # anisotropic scale
        np.array(((scale_x, 0, 0), (0, scale_y, 0), (0, 0, 1))),
        # projective distortion
        np.array(((1, 0, 0), (0, 1, 0), (persp_x, persp_y, 1))),
        # shear (affine)
        np.array(((1, shear_upper, 0), (shear_lower, 1, 0), (0, 0, 1))),
        # in-plane rotation
        np.array(((cos_a, -sin_a, 0), (sin_a, cos_a, 0), (0, 0, 1))),
        # move the image center to the origin
        np.array(((1, 0, -center_x), (0, 1, -center_y), (0, 0, 1))),
    ]

    H = factors[0]
    for factor in factors[1:]:
        H = np.dot(H, factor)
    return H
class COCOAugmentor(nn.Module):
    """Random geometric and photometric augmentor for COCO_20k images.

    ``forward`` warps a batch with random homographies (optionally followed
    by a random thin-plate-spline deformation), crops the borders, resizes to
    ``out_resolution`` and applies photometric noise. The remaining methods
    convert keypoints between the output resolution and the warp resolution
    so correspondences can be recovered from the returned transforms.
    """

    def __init__(self,device,load_dataset=True,
                 img_dir="/home/yepeng_liu/code_python/dataset/coco_20k",
                 warp_resolution=(1200, 900),
                 out_resolution=(400, 300),
                 sides_crop=0.2,
                 max_num_imgs=50,
                 num_test_imgs=10,
                 batch_size=1,
                 photometric=True,
                 geometric=True,
                 reload_step=1_000
                 ):
        """
        Args:
            device: torch device on which augmentation runs.
            load_dataset: if True, images are loaded into memory right away.
            img_dir: directory searched (non-recursively) for *.jpg and *.png.
            warp_resolution: (width, height) images are resized to on load.
            out_resolution: (width, height) of the augmented output.
            sides_crop: fraction cropped from each border after warping.
            max_num_imgs: number of training images held in memory.
            num_test_imgs: number of test images held in memory.
            batch_size: nominal batch size (kept for callers; ``forward``
                uses the actual batch size of its input).
            photometric: enable photometric augmentations.
            geometric: if False, ``forward`` forces difficulty to 0.
            reload_step: images are reloaded every ``reload_step`` forward calls.

        Raises:
            RuntimeError: if fewer than 10 images are found, or if the
                training and test subsets would overlap.
        """
        super().__init__()
        self.half=16
        self.device=device
        self.dims=warp_resolution
        self.batch_size=batch_size
        self.out_resolution=out_resolution
        self.sides_crop=sides_crop
        self.max_num_imgs=max_num_imgs
        self.num_test_imgs=num_test_imgs

        # (width, height) of the central crop at warp resolution, minus one.
        crop_w = int(self.dims[0]*(1. - self.sides_crop)) - int(self.dims[0]*self.sides_crop) - 1
        crop_h = int(self.dims[1]*(1. - self.sides_crop)) - int(self.dims[1]*self.sides_crop) - 1
        self.dims_t=torch.tensor([crop_w, crop_h]).float().to(device).view(1,1,2)
        # Per-axis factor converting output-resolution pixels to crop pixels.
        out_size = torch.tensor([out_resolution[0], out_resolution[1]]).float().to(device).view(1,1,2)
        self.dims_s=self.dims_t / out_size

        self.all_imgs=glob.glob(img_dir+'/*.jpg')+glob.glob(img_dir+'/*.png')
        self.photometric=photometric
        self.geometric=geometric
        self.cnt=1
        self.reload_step=reload_step

        list_augmentation=[
            kornia.augmentation.ColorJitter(0.15,0.15,0.15,0.15,p=1.),
            kornia.augmentation.RandomEqualize(p=0.4),
            kornia.augmentation.RandomGaussianBlur(p=0.3,sigma=(2.0,2.0),kernel_size=(7,7))
        ]
        if photometric is False:
            list_augmentation = []
        self.aug_list=kornia.augmentation.ImageSequential(*list_augmentation)

        if len(self.all_imgs)<10:
            raise RuntimeError(f'Couldnt find enough images to train. Please check the path: {img_dir}')

        if load_dataset:
            print('[COCO]: ',len(self.all_imgs),' images for training..')
            if len(self.all_imgs) - num_test_imgs < max_num_imgs:
                raise RuntimeError('Error: test set overlaps with training set! Decrease number of test imgs')
            self.load_imgs()

        self.TPS = True

    @staticmethod
    def _read_image(path):
        """Read an image with OpenCV, failing with the offending path if it cannot be decoded."""
        im = cv2.imread(path)
        if im is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise RuntimeError(f'Could not read image: {path}')
        return im

    def load_imgs(self):
        """Shuffle the image list and load fresh train/test subsets into memory.

        Training images are the first ``max_num_imgs`` paths; portrait images
        are rotated to landscape before being resized to ``self.dims``. Test
        images are the last ``num_test_imgs`` paths, resized without rotation.
        """
        # NOTE(review): the list is reshuffled on every call, so images that
        # were in the test subset can enter the training subset after a reload.
        random.shuffle(self.all_imgs)

        train = []
        for p in tqdm.tqdm(self.all_imgs[:self.max_num_imgs],desc='loading train'):
            im = self._read_image(p)
            if im.shape[0]//2 > im.shape[1]//2:
                im = np.rot90(im)
            if im.shape[0]!=self.dims[1] or im.shape[1]!=self.dims[0]:
                im = cv2.resize(im, self.dims)
            train.append(np.copy(im))
        self.train=train

        # Explicit start index: all_imgs[-0:] would select every image.
        test_start = max(0, len(self.all_imgs) - self.num_test_imgs)
        self.test=[
            cv2.resize(self._read_image(p),self.dims)
            for p in tqdm.tqdm(self.all_imgs[test_start:],desc='loading test')
        ]

    def norm_pts_grid(self, x):
        """Map output-resolution (x, y) points to [-1, 1]; (N, 2) input becomes (1, N, 2)."""
        if len(x.size()) == 2:
            x = x.view(1,-1,2)
        return (x * self.dims_s / self.dims_t) * 2. - 1

    def denorm_pts_grid(self, x):
        """Inverse of ``norm_pts_grid``: map [-1, 1] points back to output-resolution pixels."""
        if len(x.size()) == 2:
            x = x.view(1,-1,2)
        return ((x + 1) / 2.) / self.dims_s * self.dims_t

    def rnd_kps(self, shape, n = 256):
        """Sample ``n`` uniform random homogeneous keypoints, returned as a (3, n) tensor.

        Row 0 is x in [0, w), row 1 is y in [0, h), row 2 is all ones.
        """
        h, w = shape
        kps = torch.rand(size = (3,n)).to(self.device)
        kps[0,:]*=w
        kps[1,:]*=h
        kps[2,:] = 1.0
        return kps

    def warp_points(self, H, pts):
        """Apply homography ``H`` to output-resolution points.

        Args:
            H: (3, 3) homography defined in warp-resolution pixel coordinates.
            pts: (N, 2) points in output-resolution pixel coordinates.

        Returns:
            (N, 2) warped points, again in output-resolution coordinates.
        """
        scale = self.dims_s.view(-1,2)
        # Top-left corner of the central crop inside the warp-resolution image.
        offset = torch.tensor([int(self.dims[0]*self.sides_crop), int(self.dims[1]*self.sides_crop)], device = pts.device).float()

        # Output resolution -> warp resolution, then homogeneous (3, N).
        pts = pts*scale + offset
        pts = torch.vstack( [pts.t(), torch.ones(1, pts.shape[0], device = pts.device)])

        warped = torch.matmul(H, pts)
        warped = warped / warped[2,...]
        warped = warped.t()[:, :2]

        # Warp resolution -> output resolution.
        return (warped - offset) / scale

    def forward(self, x, difficulty = 0.3, TPS = False, prob_deformation = 0.5, test = False):
        """
        Perform augmentation to a batch of images.

        input:
            x -> torch.Tensor(B, C, H, W): rgb images with values in [0, 255]
            difficulty -> float: level of difficulty, 0.1 is medium, 0.3 is already pretty hard
            TPS -> bool: Whether to apply non-rigid deformations in images
            prob_deformation -> float: probability to apply a deformation
            test -> bool: not used by this method

        return:
            'output' -> torch.Tensor(B, C, H', W'): augmented images in [0, 1]
                        at ``out_resolution``
            Tuple:
                'H'    -> torch.Tensor(B, 3, 3): homography matrices
                (deformation only)
                src, weights, A are parameters from a TPS warp (all torch.Tensors)
                'mask' -> torch.Tensor(B, H', W'): mask of valid pixels after warp
            i.e. (H, mask) without TPS and (H, src, weights, A, mask) with TPS.
        """
        if self.cnt % self.reload_step == 0:
            self.load_imgs()

        if self.geometric is False:
            difficulty = 0.

        with torch.no_grad():
            x = (x/255.).to(self.device)
            batch, channels, h, w = x.shape
            shape = (h, w)

            ######## Geometric Transformations
            # One homography per image actually present in the batch.
            H = torch.tensor(
                np.array([generateRandomHomography(shape,difficulty) for _ in range(batch)]),
                dtype=torch.float32).to(self.device)
            output = kornia.geometry.transform.warp_perspective(x,H,dsize=shape,padding_mode='zeros')

            #crop % of image boundaries each side to reduce invalid pixels after warps
            low_h = int(h * self.sides_crop); low_w = int(w*self.sides_crop)
            high_h = int(h*(1. - self.sides_crop)); high_w= int(w * (1. - self.sides_crop))
            output = output[..., low_h:high_h, low_w:high_w]
            x = x[..., low_h:high_h, low_w:high_w]

            #apply TPS if desired:
            if TPS:
                tps_params = [generateRandomTPS(shape, (8,6), difficulty, prob = prob_deformation)
                              for _ in range(batch)]
                # Concatenate each of (src, weights, A) along the batch axis.
                src, weights, A = (
                    torch.cat([t.to(self.device) for t in group])
                    for group in zip(*tps_params)
                )
                output = warp_image_tps(output, src, weights, A)

            output = F.interpolate(output, self.out_resolution[::-1], mode = 'nearest')
            x = F.interpolate(x, self.out_resolution[::-1], mode = 'nearest')

            # A pixel is valid when at least one of its channels is non-zero.
            mask = ~torch.all(output == 0, dim=1, keepdim=True)
            mask = mask.expand(-1,channels,-1,-1)

            # Make-up invalid regions with texture from the batch
            rv = 1 if not TPS else 2
            output_shifted = torch.roll(x, rv, 0)
            output[~mask] = output_shifted[~mask]
            mask = mask[:, 0, :, :]

            ######## Photometric Transformations
            output = self.aug_list(output)

            b, c, h, w = output.shape
            #Correlated Gaussian Noise
            if np.random.uniform() > 0.5 and self.photometric:
                noise = F.interpolate(torch.randn_like(output)*(10/255), (h//2, w//2))
                noise = F.interpolate(noise, (h, w), mode = 'bicubic')
                output = torch.clip( output + noise, 0., 1.)

            #Random shadows
            if np.random.uniform() > 0.6 and self.photometric:
                # At least a 1x1 shadow map so small outputs do not yield an empty tensor.
                noise = torch.rand((b, 1, max(1, h//64), max(1, w//64)), device = self.device) * 1.3
                noise = torch.clip(noise, 0.25, 1.0)
                noise = F.interpolate(noise, (h, w), mode = 'bicubic')
                noise = noise.expand(-1, c, -1, -1)
                output *= noise
                output = torch.clip( output, 0., 1.)

            self.cnt+=1

        if TPS:
            return output, (H, src, weights, A, mask)
        else:
            return output, (H, mask)

    def get_correspondences(self, kps_target, T):
        """Map keypoints of a TPS-deformed target image back to the source image.

        Args:
            kps_target: (N, 2) keypoints in output-resolution pixels.
            T: tuple (H, H2, src, W, A) holding two homographies followed by
                TPS parameters. Presumably H and H2 are the (3, 3) homographies
                of the source and target views; verify against the caller.

        Returns:
            (N, 2) keypoints after undoing the TPS warp and applying
            ``H @ inverse(H2)``.
        """
        H, H2, src, W, A = T
        undeformed = self.denorm_pts_grid(
            warp_points_tps(self.norm_pts_grid(kps_target),
                            src, W, A) ).view(-1,2)

        warped_to_src = self.warp_points(H@torch.inverse(H2), undeformed)
        return warped_to_src