Jannat24 commited on
Commit
bbb5f33
·
1 Parent(s): da52bc4

frameworkeval

Browse files
Files changed (1) hide show
  1. frameworkeval.py +56 -0
frameworkeval.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision.models import vgg16
5
+ from torchmetrics.functional import structural_similarity_index_measure
6
+ from facenet_pytorch import InceptionResnetV1
7
+ from losseval.denormalize import denormalize_bin, denormalize_tr, denormalize_ar
8
+
9
+ class DF(nn.Module):
10
+ def __init__(self):
11
+ super(DF, self).__init__()
12
+ self.mse_weight = 0.25
13
+ self.perceptual_weight = 0.25
14
+ self.ssim_weight = 0.25
15
+ self.idsim_weight = 0.25
16
+ self.device = "cuda"
17
+ self.vgg = vgg16(pretrained=True).features[:16].to(device).eval()
18
+ self.facenet = InceptionResnetV1(pretrained='vggface2').to(device).eval()
19
+ for param in self.facenet.parameters():
20
+ param.requires_grad = False # Freeze the model
21
+ self.cosloss = nn.CosineEmbeddingLoss()
22
+
23
+ def perceptual_loss(self, real, fake):
24
+ with torch.no_grad(): # VGG is frozen during training
25
+ real_features = self.vgg(real)
26
+ fake_features = self.vgg(fake)
27
+ return F.mse_loss(real_features, fake_features)
28
+
29
+ def idsimilarity(self, real, fake):
30
+ with torch.no_grad():
31
+ # Extract embeddings
32
+ input_embed = self.facenet(real).to(device)
33
+ generated_embed = self.facenet(fake).to(device)
34
+ # Compute cosine similarity loss
35
+ target = torch.ones(input_embed.size(0)).to(real.device) # Target = 1 (maximize similarity)
36
+ return self.cosloss(input_embed, generated_embed, target)
37
+
38
+ def forward(self, r, f):
39
+ real = denormalize_bin(r) #[-1,1] to [0,1]
40
+ fake = denormalize_bin(f)
41
+ mse_loss = F.mse_loss(real, fake)
42
+ perceptual_loss = self.perceptual_loss(real, fake)
43
+ idsim_loss = self.idsimilarity(real, fake)
44
+ ssim = structural_similarity_index_measure(fake, real)
45
+ ssim_loss = 1 - ssim
46
+ id_si = 1 - idsim_loss
47
+
48
+ total_loss = (self.mse_weight * mse_loss) + (self.perceptual_weight * perceptual_loss) + (self.idsim_weight * idsim_loss) + (self.ssim_weight * ssim_loss)
49
+ components = {
50
+ "MSE Loss": mse_loss.item(),
51
+ "Perceptual Loss": perceptual_loss.item(),
52
+ "ID-SIM Loss": idsim_loss.item(),
53
+ "SSIM Loss": ssim_loss.item()
54
+ }
55
+
56
+ return total_loss, components