# Loss functions: contour (active-contour), IoU, patch IoU, structure (weighted
# BCE + weighted IoU), threshold regularization, auxiliary classification,
# pixel-wise combined loss, and SSIM-based losses.
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
from math import exp | |
from config import Config | |
class ContourLoss(torch.nn.Module):
    """Active-contour loss: a weighted boundary-length term plus region terms."""

    def __init__(self):
        super(ContourLoss, self).__init__()

    def forward(self, pred, target, weight=10):
        '''
        pred, target: tensors of shape (B, C, H, W), where
        target[:,:,region_in_contour] == 1 and target[:,:,region_out_contour] == 0.
        weight: scalar weight for the length term.
        '''
        # Finite differences along the height and width axes.
        grad_h = pred[:, :, 1:, :] - pred[:, :, :-1, :]   # (B, C, H-1, W)
        grad_w = pred[:, :, :, 1:] - pred[:, :, :, :-1]   # (B, C, H, W-1)
        grad_h = grad_h[:, :, 1:, :-2] ** 2               # (B, C, H-2, W-2)
        grad_w = grad_w[:, :, :-2, 1:] ** 2               # (B, C, H-2, W-2)
        grad_mag = torch.abs(grad_h + grad_w)

        # Small constant keeps sqrt differentiable where the gradient is zero.
        eps = 1e-8
        # eq.(11) in the paper; mean is used instead of sum.
        length = torch.mean(torch.sqrt(grad_mag + eps))

        c_in = torch.ones_like(pred)
        c_out = torch.zeros_like(pred)
        # eq.(12) in the paper; mean is used instead of sum.
        region_in = torch.mean(pred * (target - c_in) ** 2)
        region_out = torch.mean((1 - pred) * (target - c_out) ** 2)
        region = region_in + region_out

        return weight * length + region
class IoULoss(torch.nn.Module):
    """Soft IoU loss, summed (not averaged) over the batch."""

    def __init__(self):
        super(IoULoss, self).__init__()

    def forward(self, pred, target):
        batch = pred.shape[0]
        total = 0.0
        for idx in range(batch):
            # Soft intersection and union of the foreground maps.
            inter = torch.sum(target[idx, :, :, :] * pred[idx, :, :, :])
            union = torch.sum(target[idx, :, :, :]) + torch.sum(pred[idx, :, :, :]) - inter
            # Per-sample loss is 1 - IoU; accumulate over the batch.
            total = total + (1 - inter / union)
        # NOTE(review): intentionally NOT divided by the batch size
        # (the original had `return IoU/b` commented out).
        return total
class StructureLoss(torch.nn.Module):
    """Weighted BCE + weighted IoU loss (F3Net-style); `pred` is expected
    to be logits (BCE-with-logits is used, then sigmoid for the IoU term)."""

    def __init__(self):
        super(StructureLoss, self).__init__()

    def forward(self, pred, target):
        # Pixels where the local average label differs from the label itself
        # (i.e. near boundaries) get up-weighted.
        weight = 1 + 5 * torch.abs(
            F.avg_pool2d(target, kernel_size=31, stride=1, padding=15) - target)

        bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        weighted_bce = (weight * bce).sum(dim=(2, 3)) / weight.sum(dim=(2, 3))

        prob = torch.sigmoid(pred)
        inter = (prob * target * weight).sum(dim=(2, 3))
        union = ((prob + target) * weight).sum(dim=(2, 3))
        weighted_iou = 1 - (inter + 1) / (union - inter + 1)

        return (weighted_bce + weighted_iou).mean()
class PatchIoULoss(torch.nn.Module):
    """IoU loss accumulated over non-overlapping 64x64 spatial patches.

    Bug fix: the original looped over `target.shape[0]` / `target.shape[1]`
    (the batch and channel dimensions) and stepped the x loop by `win_y`,
    so for typical (B, C, H, W) inputs only the single top-left 64x64 patch
    was ever evaluated. Patches must tile the spatial dimensions
    (H = shape[2], W = shape[3]), stepping each axis by its own window size.
    """

    def __init__(self):
        super(PatchIoULoss, self).__init__()
        self.iou_loss = IoULoss()

    def forward(self, pred, target):
        win_y, win_x = 64, 64
        iou_loss = 0.
        for anchor_y in range(0, target.shape[2], win_y):
            for anchor_x in range(0, target.shape[3], win_x):
                # Slicing clamps at the border, so edge patches may be smaller.
                patch_pred = pred[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x]
                patch_target = target[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x]
                iou_loss += self.iou_loss(patch_pred, patch_target)
        return iou_loss
class ThrReg_loss(torch.nn.Module):
    """Threshold regularizer that pushes predictions toward 0 or 1.

    1 - (p^2 + (p-1)^2) peaks at p = 0.5, so minimizing its mean penalizes
    indecisive mid-range predictions. `gt` is accepted but unused.
    """

    def __init__(self):
        super(ThrReg_loss, self).__init__()

    def forward(self, pred, gt=None):
        sharpness = (pred - 0) ** 2 + (pred - 1) ** 2
        return torch.mean(1 - sharpness)
class ClsLoss(nn.Module):
    """
    Auxiliary classification loss for each refined class output.
    """

    def __init__(self):
        super(ClsLoss, self).__init__()
        self.config = Config()
        self.lambdas_cls = self.config.lambdas_cls
        # Criterion registry keyed by the lambda name in the config.
        self.criterions_last = {'ce': nn.CrossEntropyLoss()}

    def forward(self, preds, gt):
        loss = 0.
        for pred_lvl in preds:
            # Some prediction levels may be disabled and yield None.
            if pred_lvl is None:
                continue
            for name, criterion in self.criterions_last.items():
                loss = loss + criterion(pred_lvl, gt) * self.lambdas_cls[name]
        return loss
class PixLoss(nn.Module):
    """
    Pixel loss for each refined map output: a weighted sum of the criteria
    enabled (non-zero lambda) in the config.
    """

    def __init__(self):
        super(PixLoss, self).__init__()
        self.config = Config()
        self.lambdas_pix_last = self.config.lambdas_pix_last
        # Instantiate only the criteria whose lambda is present and truthy.
        factories = {
            'bce': nn.BCELoss,
            'iou': IoULoss,
            'iou_patch': PatchIoULoss,
            'ssim': SSIMLoss,
            'mae': nn.L1Loss,
            'mse': nn.MSELoss,
            'reg': ThrReg_loss,
            'cnt': ContourLoss,
            'structure': StructureLoss,
        }
        self.criterions_last = {
            name: factory()
            for name, factory in factories.items()
            if self.lambdas_pix_last.get(name)
        }

    def forward(self, scaled_preds, gt):
        loss = 0.
        for pred_lvl in scaled_preds:
            # Upsample the prediction to the GT resolution when shapes differ.
            if pred_lvl.shape != gt.shape:
                pred_lvl = nn.functional.interpolate(
                    pred_lvl, size=gt.shape[2:], mode='bilinear', align_corners=True)
            for name, criterion in self.criterions_last.items():
                # NOTE(review): sigmoid is applied before every criterion,
                # including StructureLoss, which expects logits and applies
                # sigmoid internally — confirm the double sigmoid is intended.
                _loss = criterion(pred_lvl.sigmoid(), gt) * self.lambdas_pix_last[name]
                loss += _loss
        return loss
class SSIMLoss(torch.nn.Module):
    """SSIM-based loss: 1 - (1 + SSIM) / 2, mapping SSIM in [-1, 1]
    to a loss in [0, 1]."""

    def __init__(self, window_size=11, size_average=True):
        super(SSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Cached Gaussian window, rebuilt lazily when the input's channel
        # count or dtype changes.
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        channel = img1.size(1)
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            # Cache for subsequent calls with the same channel count / dtype.
            self.window = window
            self.channel = channel
        ssim_val = _ssim(img1, img2, window, self.window_size, channel, self.size_average)
        return 1 - (1 + ssim_val) / 2
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length `window_size`, normalized to sum to 1."""
    center = window_size // 2
    values = [exp(-((i - center) ** 2) / float(2 * sigma ** 2)) for i in range(window_size)]
    kernel = torch.Tensor(values)
    return kernel / kernel.sum()
def create_window(window_size, channel):
    """Build a (channel, 1, window_size, window_size) Gaussian window for a
    grouped conv2d (one identical kernel per channel, sigma = 1.5).

    Fix: removed the deprecated `torch.autograd.Variable` wrapper — since
    PyTorch 0.4 it is a no-op that returns the tensor itself.
    """
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)       # (W, 1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float()           # (W, W) outer product
    kernel_2d = kernel_2d.unsqueeze(0).unsqueeze(0)           # (1, 1, W, W)
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
def _ssim(img1, img2, window, window_size, channel, size_average=True): | |
mu1 = F.conv2d(img1, window, padding = window_size//2, groups=channel) | |
mu2 = F.conv2d(img2, window, padding = window_size//2, groups=channel) | |
mu1_sq = mu1.pow(2) | |
mu2_sq = mu2.pow(2) | |
mu1_mu2 = mu1*mu2 | |
sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq | |
sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq | |
sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2 | |
C1 = 0.01**2 | |
C2 = 0.03**2 | |
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) | |
if size_average: | |
return ssim_map.mean() | |
else: | |
return ssim_map.mean(1).mean(1).mean(1) | |
def SSIM(x, y):
    """Per-pixel SSIM dissimilarity map, (1 - SSIM) / 2 clamped to [0, 1],
    using 3x3 average pooling (stride 1, padding 1) for the local statistics."""
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    pool = nn.AvgPool2d(3, 1, 1)

    mu_x = pool(x)
    mu_y = pool(y)
    mu_x_sq = mu_x ** 2
    mu_y_sq = mu_y ** 2
    mu_x_mu_y = mu_x * mu_y

    # Local (co)variances: E[x*y] - E[x]*E[y].
    sigma_x = pool(x * x) - mu_x_sq
    sigma_y = pool(y * y) - mu_y_sq
    sigma_xy = pool(x * y) - mu_x_mu_y

    numerator = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2)
    denominator = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2)

    return torch.clamp((1 - numerator / denominator) / 2, 0, 1)
def saliency_structure_consistency(x, y):
    """Mean of the per-pixel SSIM dissimilarity map between `x` and `y`."""
    return torch.mean(SSIM(x, y))