"""
Loss function implementations.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.geometry import warp_perspective
from ..misc.geometry_utils import (keypoints_to_grid, get_dist_mask,
get_common_line_mask)
def get_loss_and_weights(model_cfg, device=torch.device("cuda")):
""" Get loss functions and either static or dynamic weighting. """
# Get the global weighting policy
w_policy = model_cfg.get("weighting_policy", "static")
    if w_policy not in ["static", "dynamic"]:
        raise ValueError("[Error] Unsupported weighting policy.")
loss_func = {}
loss_weight = {}
# Get junction loss function and weight
w_junc, junc_loss_func = get_junction_loss_and_weight(model_cfg, w_policy)
loss_func["junc_loss"] = junc_loss_func.to(device)
loss_weight["w_junc"] = w_junc
# Get heatmap loss function and weight
w_heatmap, heatmap_loss_func = get_heatmap_loss_and_weight(
model_cfg, w_policy, device)
loss_func["heatmap_loss"] = heatmap_loss_func.to(device)
loss_weight["w_heatmap"] = w_heatmap
# [Optionally] get descriptor loss function and weight
if model_cfg.get("descriptor_loss_func", None) is not None:
w_descriptor, descriptor_loss_func = get_descriptor_loss_and_weight(
model_cfg, w_policy)
loss_func["descriptor_loss"] = descriptor_loss_func.to(device)
loss_weight["w_desc"] = w_descriptor
return loss_func, loss_weight
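# A minimal usage sketch (an addition, not original code): the config keys
# below are the ones actually read by the getters in this file, while the
# concrete values are illustrative assumptions.
def _example_get_loss_and_weights():
    cfg = {"weighting_policy": "static", "w_junc": 2.0, "w_heatmap": 1.0,
           "grid_size": 8, "keep_border_valid": True}
    # Returns {"junc_loss": ..., "heatmap_loss": ...} plus the static weights
    return get_loss_and_weights(cfg, torch.device("cpu"))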
def get_junction_loss_and_weight(model_cfg, global_w_policy):
""" Get the junction loss function and weight. """
junction_loss_cfg = model_cfg.get("junction_loss_cfg", {})
# Get the junction loss weight
w_policy = junction_loss_cfg.get("policy", global_w_policy)
if w_policy == "static":
w_junc = torch.tensor(model_cfg["w_junc"], dtype=torch.float32)
elif w_policy == "dynamic":
w_junc = nn.Parameter(
torch.tensor(model_cfg["w_junc"], dtype=torch.float32),
requires_grad=True)
else:
raise ValueError(
"[Error] Unknown weighting policy for junction loss weight.")
# Get the junction loss function
junc_loss_name = model_cfg.get("junction_loss_func", "superpoint")
if junc_loss_name == "superpoint":
junc_loss_func = JunctionDetectionLoss(model_cfg["grid_size"],
model_cfg["keep_border_valid"])
else:
raise ValueError("[Error] Not supported junction loss function.")
return w_junc, junc_loss_func
def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device):
""" Get the heatmap loss function and weight. """
heatmap_loss_cfg = model_cfg.get("heatmap_loss_cfg", {})
# Get the heatmap loss weight
w_policy = heatmap_loss_cfg.get("policy", global_w_policy)
if w_policy == "static":
w_heatmap = torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32)
elif w_policy == "dynamic":
w_heatmap = nn.Parameter(
torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32),
requires_grad=True)
else:
raise ValueError(
"[Error] Unknown weighting policy for junction loss weight.")
# Get the corresponding heatmap loss based on the config
heatmap_loss_name = model_cfg.get("heatmap_loss_func", "cross_entropy")
if heatmap_loss_name == "cross_entropy":
# Get the heatmap class weight (always static)
heatmap_class_w = model_cfg.get("w_heatmap_class", 1.)
class_weight = torch.tensor(
np.array([1., heatmap_class_w])).to(torch.float).to(device)
heatmap_loss_func = HeatmapLoss(class_weight=class_weight)
else:
raise ValueError("[Error] Not supported heatmap loss function.")
return w_heatmap, heatmap_loss_func
def get_descriptor_loss_and_weight(model_cfg, global_w_policy):
""" Get the descriptor loss function and weight. """
descriptor_loss_cfg = model_cfg.get("descriptor_loss_cfg", {})
# Get the descriptor loss weight
w_policy = descriptor_loss_cfg.get("policy", global_w_policy)
if w_policy == "static":
w_descriptor = torch.tensor(model_cfg["w_desc"], dtype=torch.float32)
elif w_policy == "dynamic":
w_descriptor = nn.Parameter(torch.tensor(model_cfg["w_desc"],
dtype=torch.float32), requires_grad=True)
else:
raise ValueError(
"[Error] Unknown weighting policy for descriptor loss weight.")
# Get the descriptor loss function
descriptor_loss_name = model_cfg.get("descriptor_loss_func",
"regular_sampling")
if descriptor_loss_name == "regular_sampling":
descriptor_loss_func = TripletDescriptorLoss(
descriptor_loss_cfg["grid_size"],
descriptor_loss_cfg["dist_threshold"],
descriptor_loss_cfg["margin"])
else:
raise ValueError("[Error] Not supported descriptor loss function.")
return w_descriptor, descriptor_loss_func
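# An illustrative sketch (an addition; the values are assumptions): with the
# "dynamic" policy the returned weight is a learnable nn.Parameter that the
# optimizer can tune, as used by TotalLoss below.
def _example_dynamic_descriptor_weight():
    cfg = {"w_desc": 0.5,
           "descriptor_loss_func": "regular_sampling",
           "descriptor_loss_cfg": {"grid_size": 8, "dist_threshold": 8,
                                   "margin": 1.0}}
    w_desc, desc_loss_func = get_descriptor_loss_and_weight(cfg, "dynamic")
    assert isinstance(w_desc, nn.Parameter)
    return w_desc, desc_loss_func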
def space_to_depth(input_tensor, grid_size):
""" PixelUnshuffle for pytorch. """
N, C, H, W = input_tensor.size()
# (N, C, H//bs, bs, W//bs, bs)
x = input_tensor.view(N, C, H // grid_size, grid_size, W // grid_size, grid_size)
# (N, bs, bs, C, H//bs, W//bs)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
# (N, C*bs^2, H//bs, W//bs)
x = x.view(N, C * (grid_size ** 2), H // grid_size, W // grid_size)
return x
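# A quick sanity check (an addition, not original code): for single-channel
# inputs, space_to_depth matches F.pixel_unshuffle (PyTorch >= 1.8); for
# multi-channel inputs the channel ordering of the two differs.
def _check_space_to_depth():
    x = torch.arange(64.).view(1, 1, 8, 8)
    assert torch.equal(space_to_depth(x, 4), F.pixel_unshuffle(x, 4))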
def junction_detection_loss(junction_map, junc_predictions, valid_mask=None,
grid_size=8, keep_border=True):
""" Junction detection loss. """
# Convert junc_map to channel tensor
junc_map = space_to_depth(junction_map, grid_size)
map_shape = junc_map.shape[-2:]
batch_size = junc_map.shape[0]
    # Append a dustbin channel; junction channels are scaled by 2 so that
    # argmax prefers an actual junction over the dustbin, and small uniform
    # noise breaks ties between multiple junctions in the same patch
    dust_bin_label = torch.ones(
        [batch_size, 1, map_shape[0],
         map_shape[1]]).to(junc_map.device).to(torch.int)
    junc_map = torch.cat([junc_map * 2, dust_bin_label], dim=1)
    labels = torch.argmax(
        junc_map.to(torch.float)
        + torch.distributions.Uniform(0, 0.1).sample(
            junc_map.shape).to(junc_map.device),
        dim=1)
# Also convert the valid mask to channel tensor
    valid_mask = (torch.ones_like(junction_map) if valid_mask is None
                  else valid_mask)
valid_mask = space_to_depth(valid_mask, grid_size)
    # Decide which patches are valid: with keep_border, a patch counts as
    # valid if any of its pixels is valid; otherwise every pixel must be
if keep_border:
valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int),
dim=1, keepdim=True) > 0
else:
valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int),
dim=1, keepdim=True) >= grid_size * grid_size
# Compute the classification loss
loss_func = nn.CrossEntropyLoss(reduction="none")
    # CrossEntropyLoss still needs NCHW predictions and NHW integer labels
loss = loss_func(input=junc_predictions,
target=labels.to(torch.long))
# Weighted sum by the valid mask
loss_ = torch.sum(loss * torch.squeeze(valid_mask.to(torch.float),
dim=1), dim=[0, 1, 2])
loss_final = loss_ / torch.sum(torch.squeeze(valid_mask.to(torch.float),
dim=1))
return loss_final
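# An illustrative call (an addition; the shapes are assumptions): 32x32
# images with grid_size 8 give 4x4 cells and 8*8 + 1 = 65 classes per cell
# (64 junction positions plus the dustbin).
def _example_junction_loss():
    junc_gt = (torch.rand(2, 1, 32, 32) > 0.99).float()
    junc_logits = torch.randn(2, 65, 4, 4)
    return junction_detection_loss(junc_gt, junc_logits, grid_size=8)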
def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None,
class_weight=None):
""" Heatmap prediction loss. """
# Compute the classification loss on each pixel
if class_weight is None:
loss_func = nn.CrossEntropyLoss(reduction="none")
else:
loss_func = nn.CrossEntropyLoss(class_weight, reduction="none")
loss = loss_func(input=heatmap_pred,
target=torch.squeeze(heatmap_gt.to(torch.long), dim=1))
# Weighted sum by the valid mask
# Sum over H and W
loss_spatial_sum = torch.sum(loss * torch.squeeze(
valid_mask.to(torch.float), dim=1), dim=[1, 2])
valid_spatial_sum = torch.sum(torch.squeeze(valid_mask.to(torch.float32),
dim=1), dim=[1, 2])
# Mean to single scalar over batch dimension
loss = torch.sum(loss_spatial_sum) / torch.sum(valid_spatial_sum)
return loss
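# An illustrative call (an addition; the shapes are assumptions): the ground
# truth is a binary line heatmap, the prediction holds per-pixel
# background/line logits.
def _example_heatmap_loss():
    heat_gt = (torch.rand(2, 1, 32, 32) > 0.9).float()
    heat_logits = torch.randn(2, 2, 32, 32)
    return heatmap_loss(heat_gt, heat_logits)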
class JunctionDetectionLoss(nn.Module):
""" Junction detection loss. """
def __init__(self, grid_size, keep_border):
super(JunctionDetectionLoss, self).__init__()
self.grid_size = grid_size
self.keep_border = keep_border
def forward(self, prediction, target, valid_mask=None):
return junction_detection_loss(target, prediction, valid_mask,
self.grid_size, self.keep_border)
class HeatmapLoss(nn.Module):
""" Heatmap prediction loss. """
def __init__(self, class_weight):
super(HeatmapLoss, self).__init__()
self.class_weight = class_weight
def forward(self, prediction, target, valid_mask=None):
return heatmap_loss(target, prediction, valid_mask, self.class_weight)
class RegularizationLoss(nn.Module):
""" Module for regularization loss. """
def __init__(self):
super(RegularizationLoss, self).__init__()
self.name = "regularization_loss"
self.loss_init = torch.zeros([])
    def forward(self, loss_weights):
        # Place the zero scalar on the same device as the weights, and use
        # out-of-place adds so the cached tensor is never mutated in place
        loss = self.loss_init.to(loss_weights["w_junc"].device)
        for val in loss_weights.values():
            if isinstance(val, nn.Parameter):
                loss = loss + val
        return loss
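# An illustrative check (an addition): only nn.Parameter entries, i.e. the
# dynamically weighted losses, contribute to the regularization term.
def _example_regularization_loss():
    weights = {"w_junc": nn.Parameter(torch.tensor(0.5)),
               "w_heatmap": torch.tensor(1.0)}  # static entry, skipped
    return RegularizationLoss()(weights)  # tensor(0.5000, grad_fn=...)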
def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices,
epoch, grid_size=8, dist_threshold=8,
init_dist_threshold=64, margin=1):
""" Regular triplet loss for descriptor learning. """
b_size, _, Hc, Wc = desc_pred1.size()
img_size = (Hc * grid_size, Wc * grid_size)
device = desc_pred1.device
# Extract valid keypoints
n_points = line_indices.size()[1]
valid_points = line_indices.bool().flatten()
n_correct_points = torch.sum(valid_points).item()
    if n_correct_points == 0:
        # Keep the same arity as the normal return so that callers indexing
        # [0] still work
        return (torch.tensor(0., dtype=torch.float, device=device),
                None, None, None)
# Check which keypoints are too close to be matched
    # The effective threshold starts at 2 * init_dist_threshold and decays
    # with the epoch down to dist_threshold, so negative mining gets
    # progressively harder as training goes on
dist_threshold = max(dist_threshold,
2 * init_dist_threshold // (epoch + 1))
dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold)
# Additionally ban negative mining along the same line
common_line_mask = get_common_line_mask(line_indices, valid_points)
dist_mask = dist_mask | common_line_mask
# Convert the keypoints to a grid suitable for interpolation
grid1 = keypoints_to_grid(points1, img_size)
grid2 = keypoints_to_grid(points2, img_size)
# Extract the descriptors
desc1 = F.grid_sample(desc_pred1, grid1).permute(
0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
desc1 = F.normalize(desc1, dim=1)
desc2 = F.grid_sample(desc_pred2, grid2).permute(
0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
desc2 = F.normalize(desc2, dim=1)
    # For L2-normalized descriptors, ||d1 - d2||^2 = 2 - 2 * <d1, d2>
    desc_dists = 2 - 2 * (desc1 @ desc2.t())
# Positive distance loss
pos_dist = torch.diag(desc_dists)
    # Negative distance loss: mask the positives (diagonal) and the banned
    # pairs with 4, the largest squared distance between unit vectors
    max_dist = torch.tensor(4., dtype=torch.float, device=device)
desc_dists[
torch.arange(n_correct_points, dtype=torch.long),
torch.arange(n_correct_points, dtype=torch.long)] = max_dist
desc_dists[dist_mask] = max_dist
neg_dist = torch.min(torch.min(desc_dists, dim=1)[0],
torch.min(desc_dists, dim=0)[0])
    loss = F.relu(margin + pos_dist - neg_dist)
    return loss, grid1, grid2, valid_points
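# An illustrative sketch of the hardest-negative mining above (an addition;
# the distance matrix is made up): positives sit on the diagonal, and after
# masking them with the maximal distance, the hardest negative for each point
# is the minimum over its row and column.
def _example_hardest_negative(margin=1.):
    dists = torch.tensor([[0.2, 1.5, 0.9],
                          [1.1, 0.3, 1.8],
                          [0.7, 1.2, 0.1]])
    pos_dist = torch.diag(dists)  # torch.diag copies, so masking is safe
    dists[torch.arange(3), torch.arange(3)] = 4.
    neg_dist = torch.min(dists.min(dim=1)[0], dists.min(dim=0)[0])
    return F.relu(margin + pos_dist - neg_dist)  # ~[0.5, 0.2, 0.4]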
class TripletDescriptorLoss(nn.Module):
""" Triplet descriptor loss. """
def __init__(self, grid_size, dist_threshold, margin):
super(TripletDescriptorLoss, self).__init__()
self.grid_size = grid_size
self.init_dist_threshold = 64
self.dist_threshold = dist_threshold
self.margin = margin
def forward(self, desc_pred1, desc_pred2, points1,
points2, line_indices, epoch):
return self.descriptor_loss(desc_pred1, desc_pred2, points1,
points2, line_indices, epoch)
# The descriptor loss based on regularly sampled points along the lines
def descriptor_loss(self, desc_pred1, desc_pred2, points1,
points2, line_indices, epoch):
return torch.mean(triplet_loss(
desc_pred1, desc_pred2, points1, points2, line_indices, epoch,
self.grid_size, self.dist_threshold, self.init_dist_threshold,
self.margin)[0])
class TotalLoss(nn.Module):
""" Total loss summing junction, heatma, descriptor
and regularization losses. """
def __init__(self, loss_funcs, loss_weights, weighting_policy):
super(TotalLoss, self).__init__()
# Whether we need to compute the descriptor loss
self.compute_descriptors = "descriptor_loss" in loss_funcs.keys()
self.loss_funcs = loss_funcs
self.loss_weights = loss_weights
self.weighting_policy = weighting_policy
        # Always add the regularization loss (it returns zero when no weight
        # is dynamic); its forward moves the zero scalar to the weights'
        # device, so no hard-coded .cuda() call is needed
        self.loss_funcs["reg_loss"] = RegularizationLoss()
def forward(self, junc_pred, junc_target, heatmap_pred,
heatmap_target, valid_mask=None):
""" Detection only loss. """
# Compute the junction loss
junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target,
valid_mask)
# Compute the heatmap loss
heatmap_loss = self.loss_funcs["heatmap_loss"](
heatmap_pred, heatmap_target, valid_mask)
# Compute the total loss.
if self.weighting_policy == "dynamic":
reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
total_loss = junc_loss * torch.exp(-self.loss_weights["w_junc"]) + \
heatmap_loss * torch.exp(-self.loss_weights["w_heatmap"]) + \
reg_loss
return {
"total_loss": total_loss,
"junc_loss": junc_loss,
"heatmap_loss": heatmap_loss,
"reg_loss": reg_loss,
"w_junc": torch.exp(-self.loss_weights["w_junc"]).item(),
"w_heatmap": torch.exp(-self.loss_weights["w_heatmap"]).item(),
}
elif self.weighting_policy == "static":
total_loss = junc_loss * self.loss_weights["w_junc"] + \
heatmap_loss * self.loss_weights["w_heatmap"]
return {
"total_loss": total_loss,
"junc_loss": junc_loss,
"heatmap_loss": heatmap_loss
}
else:
raise ValueError("[Error] Unknown weighting policy.")
def forward_descriptors(self,
junc_map_pred1, junc_map_pred2, junc_map_target1,
junc_map_target2, heatmap_pred1, heatmap_pred2, heatmap_target1,
heatmap_target2, line_points1, line_points2, line_indices,
desc_pred1, desc_pred2, epoch, valid_mask1=None,
valid_mask2=None):
""" Loss for detection + description. """
# Compute junction loss
junc_loss = self.loss_funcs["junc_loss"](
torch.cat([junc_map_pred1, junc_map_pred2], dim=0),
torch.cat([junc_map_target1, junc_map_target2], dim=0),
torch.cat([valid_mask1, valid_mask2], dim=0)
)
# Get junction loss weight (dynamic or not)
if isinstance(self.loss_weights["w_junc"], nn.Parameter):
w_junc = torch.exp(-self.loss_weights["w_junc"])
else:
w_junc = self.loss_weights["w_junc"]
# Compute heatmap loss
heatmap_loss = self.loss_funcs["heatmap_loss"](
torch.cat([heatmap_pred1, heatmap_pred2], dim=0),
torch.cat([heatmap_target1, heatmap_target2], dim=0),
torch.cat([valid_mask1, valid_mask2], dim=0)
)
# Get heatmap loss weight (dynamic or not)
if isinstance(self.loss_weights["w_heatmap"], nn.Parameter):
w_heatmap = torch.exp(-self.loss_weights["w_heatmap"])
else:
w_heatmap = self.loss_weights["w_heatmap"]
# Compute the descriptor loss
descriptor_loss = self.loss_funcs["descriptor_loss"](
desc_pred1, desc_pred2, line_points1,
line_points2, line_indices, epoch)
# Get descriptor loss weight (dynamic or not)
if isinstance(self.loss_weights["w_desc"], nn.Parameter):
w_descriptor = torch.exp(-self.loss_weights["w_desc"])
else:
w_descriptor = self.loss_weights["w_desc"]
# Update the total loss
total_loss = (junc_loss * w_junc
+ heatmap_loss * w_heatmap
+ descriptor_loss * w_descriptor)
        outputs = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            # By now the dynamic weights are exp(-w) tensors rather than
            # nn.Parameters, so test for torch.Tensor when logging floats
            "w_junc": w_junc.item()
                if isinstance(w_junc, torch.Tensor) else w_junc,
            "w_heatmap": w_heatmap.item()
                if isinstance(w_heatmap, torch.Tensor) else w_heatmap,
            "descriptor_loss": descriptor_loss,
            "w_desc": w_descriptor.item()
                if isinstance(w_descriptor, torch.Tensor) else w_descriptor
        }
# Compute the regularization loss
reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
total_loss += reg_loss
outputs.update({
"reg_loss": reg_loss,
"total_loss": total_loss
})
return outputs
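# An end-to-end sketch on CPU (an addition; the shapes and config values are
# assumptions), combining the pieces above for the detection-only case.
def _example_total_loss():
    cfg = {"weighting_policy": "static", "w_junc": 2.0, "w_heatmap": 1.0,
           "grid_size": 8, "keep_border_valid": True}
    loss_funcs, loss_weights = get_loss_and_weights(cfg, torch.device("cpu"))
    criterion = TotalLoss(loss_funcs, loss_weights, cfg["weighting_policy"])
    junc_gt = (torch.rand(2, 1, 32, 32) > 0.99).float()
    heat_gt = (torch.rand(2, 1, 32, 32) > 0.9).float()
    valid_mask = torch.ones(2, 1, 32, 32)
    return criterion(torch.randn(2, 65, 4, 4), junc_gt,
                     torch.randn(2, 2, 32, 32), heat_gt, valid_mask)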