import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import swapae.util as util
from .stylegan2_layers import Downsample


def gan_loss(pred, should_be_classified_as_real):
    """Non-saturating logistic GAN loss, returned per sample."""
    bs = pred.size(0)
    if should_be_classified_as_real:
        return F.softplus(-pred).view(bs, -1).mean(dim=1)
    else:
        return F.softplus(pred).view(bs, -1).mean(dim=1)


def feature_matching_loss(xs, ys, equal_weights=False, num_layers=6):
    """L1 distance between two lists of feature maps.

    Unless equal_weights is set, deeper layers receive exponentially larger
    weights (each layer counts twice as much as the previous one).
    """
    loss = 0.0
    for i, (x, y) in enumerate(zip(xs[:num_layers], ys[:num_layers])):
        if equal_weights:
            weight = 1.0 / min(num_layers, len(xs))
        else:
            weight = 1 / (2 ** (min(num_layers, len(xs)) - i))
        loss = loss + (x - y).abs().flatten(1).mean(1) * weight
    return loss


class IntraImageNCELoss(nn.Module):
    """Contrastive loss between spatial locations within the same image."""

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, query, target):
        num_locations = min(query.size(2) * query.size(3),
                            self.opt.intraimage_num_locations)
        bs = query.size(0)
        # Note: this permutes indices 0..num_locations-1, so when the cap
        # applies, only the first num_locations flattened positions are used.
        patch_ids = torch.randperm(num_locations, device=query.device)
        query = query.flatten(2, 3)
        target = target.flatten(2, 3)

        # both query and target are of size B x C x N
        query = query[:, :, patch_ids]
        target = target[:, :, patch_ids]

        # All-pairs similarities within each image; the matching location is
        # the positive class. 0.07 is the usual InfoNCE temperature.
        cosine_similarity = torch.bmm(query.transpose(1, 2), target)
        cosine_similarity = cosine_similarity.flatten(0, 1)
        target_label = torch.arange(num_locations, dtype=torch.long,
                                    device=query.device).repeat(bs)
        loss = self.cross_entropy_loss(cosine_similarity / 0.07, target_label)
        return loss


class VGG16Loss(torch.nn.Module):
    """Perceptual loss on the first three pre-pool VGG16 feature maps."""

    def __init__(self):
        super().__init__()
        self.vgg_convs = torchvision.models.vgg16(pretrained=True).features
        # Shift/scale that maps inputs in [-1, 1] to (approximately) the
        # ImageNet normalization that VGG16 expects.
        self.register_buffer(
            'mean',
            torch.tensor([0.485, 0.456, 0.406])[None, :, None, None] - 0.5)
        self.register_buffer(
            'stdev',
            torch.tensor([0.229, 0.224, 0.225])[None, :, None, None] * 2)
        self.downsample = Downsample([1, 2, 1], factor=2)

    def copy_section(self, source, start, end):
        section = torch.nn.Sequential()
        for i in range(start, end):
            section.add_module(str(i), source[i])
        return section

    def vgg_forward(self, x):
        x = (x - self.mean) / self.stdev
        features = []
        for _, layer in self.vgg_convs.named_children():
            if "MaxPool2d" == type(layer).__name__:
                # Record the pre-pool activation, then downsample with a blur
                # filter (antialiasing) instead of max pooling.
                features.append(x)
                if len(features) == 3:
                    break
                x = self.downsample(x)
            else:
                x = layer(x)
        return features

    def forward(self, x, y):
        y = y.detach()
        loss = 0
        # Only the first three entries are used, since vgg_forward stops
        # after three feature maps; the sum is renormalized below.
        weights = [1 / 32, 1 / 16, 1 / 8, 1 / 4, 1.0]
        #weights = [1] * 5
        total_weights = 0.0
        for i, (xf, yf) in enumerate(zip(self.vgg_forward(x),
                                         self.vgg_forward(y))):
            loss += F.l1_loss(xf, yf) * weights[i]
            total_weights += weights[i]
        return loss / total_weights


class NCELoss(nn.Module):
    """InfoNCE loss with one positive per query and a shared negative bank."""

    def __init__(self):
        super().__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, query, target, negatives):
        query = util.normalize(query.flatten(1))
        target = util.normalize(target.flatten(1))
        negatives = util.normalize(negatives.flatten(1))
        bs = query.size(0)
        sim_pos = (query * target).sum(dim=1, keepdim=True)
        sim_neg = torch.mm(query, negatives.transpose(0, 1))
        # The positive similarity sits in column 0, so the target class is 0.
        all_similarity = torch.cat([sim_pos, sim_neg], dim=1) / 0.07

        #sim_target = util.compute_similarity_logit(query, target)
        #sim_target = torch.mm(query, target.transpose(0, 1)) / 0.07
        #sim_query = util.compute_similarity_logit(query, query)
        #util.set_diag_(sim_query, -20.0)
        #all_similarity = torch.cat([sim_target, sim_query], axis=1)
        #target_label = torch.arange(bs, dtype=torch.long,
        #                            device=query.device)

        target_label = torch.zeros(bs, dtype=torch.long, device=query.device)
        loss = self.cross_entropy_loss(all_similarity, target_label)
        return loss
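
# Example (an illustrative sketch, not called anywhere in this file; shapes
# are assumptions): each row of `query` is scored against its paired row in
# `target` (the positive) and against every row of a shared negative bank,
# so class 0 is always the correct label.
#
#     nce = NCELoss()
#     loss = nce(torch.randn(8, 128),     # query embeddings
#                torch.randn(8, 128),     # positives, paired by row
#                torch.randn(256, 128))   # negative bank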

class ScaleInvariantReconstructionLoss(nn.Module):
    def forward(self, query, target):
        # First term: one minus the inner product between the query and
        # target vectors at each corresponding spatial location.
        query_flat = query.transpose(1, 3)
        target_flat = target.transpose(1, 3)
        dist = 1.0 - torch.bmm(
            query_flat[:, :, :, None, :].flatten(0, 2),
            target_flat[:, :, :, :, None].flatten(0, 2),
        )

        # Second term: inner products between randomly drawn, deliberately
        # mismatched target samples (torch.flip reverses the batch), clamped
        # so that only positive similarity is penalized. Note that
        # randperm(num_samples) permutes indices 0..num_samples-1, so when
        # the cap of 64 applies, only the first 64 rows are sampled.
        target_spatially_flat = target.flatten(1, 2)
        num_samples = min(target_spatially_flat.size(1), 64)
        random_indices = torch.randperm(num_samples, dtype=torch.long,
                                        device=target.device)
        randomly_sampled = target_spatially_flat[:, random_indices]
        random_indices = torch.randperm(num_samples, dtype=torch.long,
                                        device=target.device)
        another_random_sample = target_spatially_flat[:, random_indices]
        random_similarity = torch.bmm(
            randomly_sampled[:, :, None, :].flatten(0, 1),
            torch.flip(another_random_sample, [0])[:, :, :, None].flatten(0, 1),
        )
        return dist.mean() + random_similarity.clamp(min=0.0).mean()
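

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch for shape checking, not part
    # of the library. Because of the relative import at the top, this file
    # must be run as a module from its package (the exact module path depends
    # on the repository layout). VGG16Loss is omitted since it downloads
    # pretrained weights.
    import types

    torch.manual_seed(0)

    # Non-saturating GAN loss on random discriminator logits.
    logits = torch.randn(4, 1, 8, 8)
    print("gan_loss:", gan_loss(logits, True).mean().item())

    # Feature matching over three pairs of fake discriminator features.
    xs = [torch.randn(4, 16, 8, 8) for _ in range(3)]
    ys = [torch.randn(4, 16, 8, 8) for _ in range(3)]
    print("feature_matching_loss:",
          feature_matching_loss(xs, ys).mean().item())

    # Intra-image NCE with a stand-in options object; the field value is a
    # hypothetical choice for this sketch.
    opt = types.SimpleNamespace(intraimage_num_locations=16)
    feat = F.normalize(torch.randn(2, 8, 4, 4), dim=1)
    print("IntraImageNCELoss:", IntraImageNCELoss(opt)(feat, feat).item())

    # Scale-invariant reconstruction loss on channel-normalized features.
    q = F.normalize(torch.randn(2, 8, 4, 4), dim=1)
    t = F.normalize(torch.randn(2, 8, 4, 4), dim=1)
    print("ScaleInvariantReconstructionLoss:",
          ScaleInvariantReconstructionLoss()(q, t).item())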