NamedCurves / utils /setup_criterion.py
davidserra9's picture
First commit from github repo
117183e verified
raw
history blame
1.05 kB
from torch import nn
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
class BackboneL2SSIMLoss(nn.Module):
    """Combined MSE + SSIM loss over a backbone output and a final prediction.

    The value returned by ``forward`` is::

        alpha * MSE(backbone, target)
            + MSE(prediction, target)
            + (1 - SSIM(prediction, target))

    Args:
        ssim_window_size: Forwarded to the SSIM metric as ``kernel_size``.
        alpha: Weight applied to the backbone MSE term only; the two
            prediction terms are unweighted.
    """

    def __init__(self, ssim_window_size=5, alpha=0.5):
        # Function-scope import: the visible file header only brings in
        # ``nn`` from torch, and the CUDA availability check needs ``torch``.
        import torch

        super().__init__()
        self.ssim_window_size = ssim_window_size
        self.alpha = alpha

        ssim_metric = SSIM(kernel_size=ssim_window_size)
        # Keep the historical GPU placement when CUDA exists, but do not
        # crash at construction time on CPU-only machines.
        if torch.cuda.is_available():
            ssim_metric = ssim_metric.cuda()
        self.ssim_loss = ssim_metric
        self.l2_loss = nn.MSELoss()

    def forward(self, backbone, prediction, target):
        """Return the scalar loss for one batch.

        Args:
            backbone: Intermediate output compared against ``target`` with MSE.
            prediction: Final output compared against ``target`` with both
                MSE and SSIM.
            target: Reference tensor.
                NOTE(review): presumably image batches in the layout the
                torchmetrics SSIM metric expects -- confirm with the caller.

        Returns:
            ``alpha * l2(backbone) + l2(prediction) + (1 - ssim(prediction))``.
        """
        # Make the metric follow the inputs so the loss works regardless of
        # where it was constructed (no-op when already on that device).
        self.ssim_loss.to(prediction.device)

        # SSIM is a similarity (higher is better), so 1 - SSIM is minimised.
        ssim_loss_pred = 1.0 - self.ssim_loss(prediction, target)
        l2_loss_pred = self.l2_loss(prediction, target)
        l2_loss_backbone = self.l2_loss(backbone, target)
        return self.alpha * l2_loss_backbone + l2_loss_pred + ssim_loss_pred
def get_criterion(criterion_config):
    """Build the loss module selected by ``criterion_config``.

    Args:
        criterion_config: Configuration object read via attribute access for
            ``type`` and via item access for ``'params'``; the params mapping
            is expanded into the loss constructor's keyword arguments.

    Returns:
        A ``BackboneL2SSIMLoss`` instance when ``type`` is
        ``'backbone-L2-SSIM'``.

    Raises:
        ValueError: If ``type`` names any other criterion.
    """
    requested = criterion_config.type

    # Guard clause: reject anything that is not a known criterion up front.
    # TODO: Add more criterion types here (L1, L2 ...)
    if requested != 'backbone-L2-SSIM':
        raise ValueError(f"Unsupported criterion type: {requested}")

    loss_kwargs = criterion_config['params']
    return BackboneL2SSIMLoss(**loss_kwargs)