Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,050 Bytes
117183e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from torch import nn
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
class BackboneL2SSIMLoss(nn.Module):
def __init__(self, ssim_window_size=5, alpha=0.5):
super(BackboneL2SSIMLoss, self).__init__()
self.ssim_window_size = ssim_window_size
self.alpha = alpha
self.ssim_loss = SSIM(kernel_size=ssim_window_size).cuda()
self.l2_loss = nn.MSELoss()
def forward(self, backbone, prediction, target):
ssim_loss_pred = (1.0 - self.ssim_loss(prediction, target))
l2_loss_pred = self.l2_loss(prediction, target)
l2_loss_backbone = self.l2_loss(backbone, target)
return self.alpha*l2_loss_backbone + l2_loss_pred + ssim_loss_pred
def get_criterion(criterion_config):
criterion_type = criterion_config.type
if criterion_type == 'backbone-L2-SSIM':
return BackboneL2SSIMLoss(**criterion_config['params'])
#TODO: Add more criterion types here (L1, L2 ...)
else:
raise ValueError(f"Unsupported criterion type: {criterion_type}") |