Spaces:
Runtime error
Runtime error
# Author: Bingxin Ke
# Last modified: 2024-02-22
import torch | |
def get_loss(loss_name, **kwargs):
    """Instantiate a loss object from its registry name.

    Args:
        loss_name (str): One of ``"silog_mse"``, ``"silog_rmse"``,
            ``"mse_loss"``, ``"l1_loss"``, ``"l1_loss_with_mask"`` or
            ``"mean_abs_rel"``.
        **kwargs: Forwarded to the selected loss constructor. They are
            ignored for ``"mean_abs_rel"``, whose constructor takes no
            arguments.

    Returns:
        The constructed loss callable.

    Raises:
        NotImplementedError: If ``loss_name`` is not one of the names above.
    """
    if loss_name == "silog_mse":
        criterion = SILogMSELoss(**kwargs)
    elif loss_name == "silog_rmse":
        criterion = SILogRMSELoss(**kwargs)
    elif loss_name == "mse_loss":
        criterion = torch.nn.MSELoss(**kwargs)
    elif loss_name == "l1_loss":
        criterion = torch.nn.L1Loss(**kwargs)
    elif loss_name == "l1_loss_with_mask":
        criterion = L1LossWithMask(**kwargs)
    elif loss_name == "mean_abs_rel":
        criterion = MeanAbsRelLoss()
    else:
        # Name the rejected value so config typos are easy to spot.
        raise NotImplementedError(f"Unknown loss name: {loss_name!r}")
    return criterion
class L1LossWithMask:
    """Mean absolute error between prediction and ground truth, optionally masked.

    The absolute error is summed over the last two dimensions and divided by
    the number of counted pixels there (all of them, or only the valid ones
    when a mask is given), giving one value per leading index.
    """

    def __init__(self, batch_reduction=False):
        """
        Args:
            batch_reduction (bool, optional): If True, average the per-sample
                losses into a single scalar. Defaults to False.
        """
        self.batch_reduction = batch_reduction

    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        """Compute the (masked) L1 loss.

        Args:
            depth_pred (torch.Tensor): Predicted depth.
            depth_gt (torch.Tensor): Ground-truth depth, same shape as
                ``depth_pred``.
            valid_mask (torch.Tensor, optional): Boolean tensor, True where a
                pixel should count. It must be usable as a boolean index into
                ``depth_pred - depth_gt``.

        Returns:
            torch.Tensor: Loss per leading index (reduced over the last two
            dims), or a scalar when ``batch_reduction`` is True.
        """
        diff = depth_pred - depth_gt
        if valid_mask is not None:
            # Zero invalid pixels so they add nothing to the sum.
            diff[~valid_mask] = 0
            n = valid_mask.sum((-1, -2))
        else:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]
        # Reduce over the same (-1, -2) dims that `n` counts, so each sample is
        # normalized by its own pixel count rather than dividing a whole-batch
        # sum by a per-sample count.
        # NOTE(review): a sample with no valid pixels gives 0 / 0 = NaN.
        loss = torch.sum(torch.abs(diff), (-1, -2)) / n
        if self.batch_reduction:
            loss = loss.mean()
        return loss
class MeanAbsRelLoss:
    """Absolute relative error ``|pred - gt| / |gt|``, averaged over dim 0.

    NOTE(review): only dim 0 is reduced, so the result keeps every remaining
    dimension instead of being a scalar. Zero entries in ``gt`` yield inf/NaN.
    """

    def __init__(self) -> None:
        """This loss takes no configuration."""

    def __call__(self, pred, gt):
        relative_error = (pred - gt) / gt
        return relative_error.abs().mean(dim=0)
class SILogMSELoss:
    """Scale-invariant log mean-squared-error loss.

    With ``d`` the log-depth difference and ``n`` the number of counted pixels
    over the last two dims, the per-sample loss is
    ``sum(d**2) / n - lamb * sum(d)**2 / n**2``.
    """

    def __init__(self, lamb, log_pred=True, batch_reduction=True):
        """
        Args:
            lamb (float): Weight of the scale-invariance term. lamb=1 makes
                the loss scale invariant; lamb=0 reduces it to an L2 loss on
                log depth.
            log_pred (bool, optional): True if the model already predicts
                log depth, in which case ``depth_pred`` is used as-is rather
                than log-transformed.
            batch_reduction (bool, optional): If True, average the per-sample
                losses into a single scalar.
        """
        super().__init__()
        self.lamb = lamb
        self.pred_in_log = log_pred
        self.batch_reduction = batch_reduction

    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        if self.pred_in_log:
            pred_log = depth_pred
        else:
            # Clip so log() stays finite for zero or negative predictions.
            pred_log = torch.log(torch.clip(depth_pred, 1e-8))
        residual = pred_log - torch.log(depth_gt)

        spatial = (-1, -2)
        if valid_mask is None:
            count = depth_gt.shape[-2] * depth_gt.shape[-1]
        else:
            # Masked-out pixels contribute nothing to either sum below.
            residual[~valid_mask] = 0
            count = valid_mask.sum(spatial)

        mean_sq = residual.square().sum(spatial) / count
        scale_term = self.lamb * residual.sum(spatial).square() / (count**2)
        per_sample = mean_sq - scale_term
        return per_sample.mean() if self.batch_reduction else per_sample
class SILogRMSELoss:
    """Scale-invariant log root-mean-squared-error loss.

    With ``d`` the log-depth difference and ``n`` the number of counted pixels
    over the last two dims, the per-sample value is
    ``sqrt(sum(d**2) / n - lamb * sum(d)**2 / n**2)``. The per-sample values
    are averaged and then multiplied by ``alpha``.
    """

    def __init__(self, lamb, alpha, log_pred=True):
        """
        Args:
            lamb (float): Weight of the scale-invariance term. lamb=1 makes
                the loss scale invariant; lamb=0 reduces it to an L2 loss on
                log depth.
            alpha (float): Multiplicative scale applied to the final loss.
            log_pred (bool, optional): True if the model already predicts
                log depth, in which case ``depth_pred`` is used as-is rather
                than log-transformed.
        """
        super().__init__()
        self.lamb = lamb
        self.alpha = alpha
        self.pred_in_log = log_pred

    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        """Compute the loss.

        Args:
            depth_pred (torch.Tensor): Predicted depth (or log depth when
                ``log_pred`` is True).
            depth_gt (torch.Tensor): Ground-truth depth.
            valid_mask (torch.Tensor, optional): Boolean tensor, True where a
                pixel should count. If None, every pixel counts.

        Returns:
            torch.Tensor: Scalar loss.
        """
        # Clip before log (as SILogMSELoss does) so zero or negative
        # predictions give a finite value instead of -inf/NaN.
        log_depth_pred = (
            depth_pred if self.pred_in_log else torch.log(torch.clip(depth_pred, 1e-8))
        )
        log_depth_gt = torch.log(depth_gt)
        # Formulation adapted from https://github.com/aliyun/NeWCRFs, computed
        # per sample with a pixel mask instead of flattening the batch.
        diff = log_depth_pred - log_depth_gt
        if valid_mask is not None:
            # Masked-out pixels contribute nothing to either sum.
            diff[~valid_mask] = 0
            n = valid_mask.sum((-1, -2))
        else:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]
        diff2 = torch.pow(diff, 2)
        first_term = torch.sum(diff2, (-1, -2)) / n
        second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2)
        # NOTE(review): for lamb close to 1 and near-constant diff, rounding
        # may push the radicand slightly below 0, making sqrt return NaN.
        loss = torch.sqrt(first_term - second_term).mean() * self.alpha
        return loss