|
|
|
|
|
|
|
|
|
from pdb import set_trace as bb |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from core import functional as myF |
|
|
|
|
|
class DeepMatchingLoss(nn.Module):
    """Unsupervised matching loss based on DeepMatching (IJCV'16).

    A 4D correlation volume is computed between two dense descriptor maps,
    aggregated through a pyramid of max-pooling + sparse convolutions, and the
    top pyramid level is turned into a scalar loss.

    Args:
        eps: (float) epsilon for the L1 normalization. Kinda handles unmatched pixels.
        atleast: (int) minimum image size at which the pyramid construction stops.
        sub: (int) prior subsampling
        way: (str) which way to compute the asymmetric matching ('1', '2' or '12')
        border: (int) ignore pixels too close to the border
        rectify_p: (float) non-linear power-rectification in DeepMatching
    """

    def __init__(self, eps=0.03, atleast=5, sub=2, way='12', border=16, rectify_p=1.5):
        super().__init__()
        assert way in ('1', '2', '12')
        self.subsample = sub
        self.border = border
        self.way = way
        self.atleast = atleast
        self.rectify_p = rectify_p
        self.eps = eps

        # Memoizes the tensors returned by `myF.children`, keyed by
        # (level, H1, W1, H2, W2, device) -- see `forward_level`.
        self._cache = {}

    def rectify(self, corr):
        """Clamp negative correlations to 0, then raise to the power `rectify_p`.

        NOTE: the clamp is done in-place (`clip_`), so the tensor passed in is
        modified; the returned tensor is a new one holding the powered values.
        """
        corr = corr.clip_(min=0)
        corr = corr ** self.rectify_p
        return corr

    def forward(self, desc1, desc2, **kw):
        """Compute the loss in the direction(s) selected by `self.way`.

        Returns:
            dict with a single key 'deepm_loss': the sum of the enabled
            one-way losses divided by the number of enabled directions.
        """
        # 1 --> 2
        loss1 = self.forward_oneway(desc1, desc2, **kw) \
                if '1' in self.way else 0

        # 2 --> 1
        loss2 = self.forward_oneway(desc2, desc1, **kw) \
                if '2' in self.way else 0

        return dict(deepm_loss=(loss1 + loss2) / len(self.way))

    def forward_oneway(self, desc1, desc2, dbg=(), **kw):
        """Asymmetric loss for matching `desc1` against `desc2`.

        Both inputs are indexed as 4D tensors (batch, channels, height, width)
        and must agree on their first two dimensions.  `dbg` and `**kw` are
        accepted but not used here.

        Returns:
            A scalar tensor, or the int 0 when the batch is empty.
        """
        assert desc1.shape[:2] == desc2.shape[:2]

        # Prior subsampling: drop `border` pixels on each side and keep one
        # pixel every `subsample`. `-0 or None` keeps the slice open-ended
        # when border == 0.
        s = slice(self.border, -self.border or None, self.subsample)
        desc1, desc2 = desc1[..., s, s], desc2[..., s, s]
        # Further subsample the first map only: one position out of 4, starting at 2.
        desc1 = desc1[:, :, 2::4, 2::4]
        B, D, H1, W1, H2, W2 = desc1.shape + desc2.shape[-2:]
        if B == 0:
            return 0  # empty batch

        # Initial 4D correlation volume: dot product between every remaining
        # position of desc1 and every position of desc2.
        corr = torch.bmm(
            desc1.reshape(B, D, -1).transpose(1, 2),
            desc2.reshape(B, D, -1),
        ).view(B, H1, W1, H2, W2)

        # Build the pyramid and keep only its last (top) level.
        pyramid = self.deep_matching(corr)
        corr = pyramid[-1]
        corr = self.rectify(corr)

        # L1 normalization over the last two dims, softened by `eps`.
        B, H1, W1, H2, W2 = corr.shape
        corr = corr / (corr.reshape(B, H1 * W1, -1).sum(dim=-1).view(B, H1, W1, 1, 1) + self.eps)

        # Negated squared L2 norm, averaged over the B*H1*W1 leading cells.
        loss = - torch.square(corr).sum() / (B * H1 * W1)
        return loss

    def deep_matching(self, corr):
        """Build the correlation pyramid, starting from `corr` (level 0).

        Levels are appended until either the weights returned by the last
        level sum to zero, or the smallest of the two trailing dims of the
        correlation falls below `2 * self.atleast`.

        Returns:
            list of correlation tensors, one per level (input included).
        """
        weights = None
        pyramid = [corr]
        for level in range(1, 999):
            corr, weights = self.forward_level(level, corr, weights)
            pyramid.append(corr)

            if weights.sum() == 0:
                break  # nothing left to aggregate
            if min(corr.shape[-2:]) < 2 * self.atleast:
                break  # trailing dims became too small
        return pyramid

    def forward_level(self, level, corr, weights):
        """Compute pyramid level `level` from the previous level `corr`.

        Steps: 3x3 stride-2 max-pooling over the last two dims, rectification,
        then `sparse_conv` using the (cached) output of `myF.children`.

        Returns:
            Whatever `sparse_conv` returns: (new correlation, new weights).
        """
        B, H1, W1, H2, W2 = corr.shape

        # Max-pooling over the last two dims; the (H1, W1) dims are folded
        # into the channel dim for the duration of the pooling.
        pooled = F.max_pool2d(corr.view(B, H1 * W1, H2, W2), 3, padding=1, stride=2)
        pooled = pooled.view(B, H1, W1, *pooled.shape[-2:])

        pooled = self.rectify(pooled)

        # The cached tensor lives on a specific device, so the device must be
        # part of the key: otherwise moving the module (or sharing `_cache`
        # between replicas on several GPUs) would hand a tensor from the wrong
        # device to `sparse_conv`.
        key = level, H1, W1, H2, W2, corr.device
        if key not in self._cache:
            _, true_h1, true_w1, true_h2, true_w2 = myF.true_corr_shape(pooled.shape, level - 1)
            self._cache[key] = myF.children(level, true_h1, true_w1, true_h2, true_w2).to(corr.device)

        return sparse_conv(level, pooled, self._cache[key], weights)
|
|
|
|
|
def sparse_conv(level, corr, parents, weights=None, border_norm=0.9):
    """Aggregate child correlation maps into parent maps, one quadrant at a time.

    Args:
        level: pyramid level; `level - 1` is forwarded to `myF.true_corr_shape`.
        corr: correlation tensor, reshaped here to (B, -1, H2, W2).
        parents: 2D index tensor; column `c` (0..3) is read for the c-th
            quadrant and `len(parents)` gives the number of output maps.
            (presumably the output of `myF.children` -- confirm with caller)
        weights: optional per-child weights, reshaped to (n_children, 1, 1)
            and multiplied into both the correlation and the normalizer.
        border_norm: forwarded as `norm=` to `myF.norm_borders`.

    Returns:
        (res, weights): `res` viewed as (B, H1+1, W1+1, H2, W2), and the
        output of `myF.norm_borders` with its last entry dropped.
    """
    B, H1, W1, H2, W2 = myF.true_corr_shape(corr.shape, level - 1)
    n_parents = len(parents)

    corr = corr.view(B, -1, H2, W2)
    n_children = corr.shape[1]

    # One extra slot (index n_parents) is allocated in both accumulators; it
    # is the `default` target given to `inverse_mapping` and is dropped
    # before returning.
    res = corr.new_zeros((B, n_parents + 1, H2, W2))
    nrm = corr.new_full((n_parents + 1, 3, 3), torch.finfo(corr.dtype).eps)
    ones = nrm.new_ones((n_children, 1, 1))
    if weights is not None:
        weights = weights.view(n_children, 1, 1)
        corr = corr * weights[None]
        ones *= weights

    def shifted(offset):
        # Negative offset: everything but the last row/col; otherwise
        # everything but the first.
        if offset < 0:
            return slice(0, -1)
        return slice(1, None)

    # (y, x) offsets in the same order as the columns of `parents`.
    quadrants = ((-1, -1), (-1, 1), (1, -1), (1, 1))
    for column, (y, x) in enumerate(quadrants):
        src_layers = parents[:, column]

        tgt_layers = myF.inverse_mapping(
            src_layers, max_elem=n_children, default=n_parents)[:-1]

        # Keep only the children that map to a real parent (not the spare slot).
        sel = myF.good_slice(tgt_layers < n_parents)
        targets = tgt_layers[sel]

        # In-place accumulation into shifted views of the outputs.
        res[:, :, shifted(-y), shifted(-x)].index_add_(
            1, targets, corr[:, sel, shifted(y), shifted(x)])
        nrm[:, shifted(-y), shifted(-x)].index_add_(
            0, targets, ones[sel].expand(-1, 2, 2))

    weights = myF.norm_borders(res, nrm, norm=border_norm)[:-1]

    res = res[:, :-1]
    return res.view(B, H1 + 1, W1 + 1, *res.shape[-2:]), weights
|
|
|
|