Spaces:
Sleeping
Sleeping
frameworkeval
Browse files- frameworkeval.py +56 -0
frameworkeval.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision.models import vgg16
|
5 |
+
from torchmetrics.functional import structural_similarity_index_measure
|
6 |
+
from facenet_pytorch import InceptionResnetV1
|
7 |
+
from losseval.denormalize import denormalize_bin, denormalize_tr, denormalize_ar
|
8 |
+
|
9 |
+
class DF(nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super(DF, self).__init__()
|
12 |
+
self.mse_weight = 0.25
|
13 |
+
self.perceptual_weight = 0.25
|
14 |
+
self.ssim_weight = 0.25
|
15 |
+
self.idsim_weight = 0.25
|
16 |
+
self.device = "cuda"
|
17 |
+
self.vgg = vgg16(pretrained=True).features[:16].to(device).eval()
|
18 |
+
self.facenet = InceptionResnetV1(pretrained='vggface2').to(device).eval()
|
19 |
+
for param in self.facenet.parameters():
|
20 |
+
param.requires_grad = False # Freeze the model
|
21 |
+
self.cosloss = nn.CosineEmbeddingLoss()
|
22 |
+
|
23 |
+
def perceptual_loss(self, real, fake):
|
24 |
+
with torch.no_grad(): # VGG is frozen during training
|
25 |
+
real_features = self.vgg(real)
|
26 |
+
fake_features = self.vgg(fake)
|
27 |
+
return F.mse_loss(real_features, fake_features)
|
28 |
+
|
29 |
+
def idsimilarity(self, real, fake):
|
30 |
+
with torch.no_grad():
|
31 |
+
# Extract embeddings
|
32 |
+
input_embed = self.facenet(real).to(device)
|
33 |
+
generated_embed = self.facenet(fake).to(device)
|
34 |
+
# Compute cosine similarity loss
|
35 |
+
target = torch.ones(input_embed.size(0)).to(real.device) # Target = 1 (maximize similarity)
|
36 |
+
return self.cosloss(input_embed, generated_embed, target)
|
37 |
+
|
38 |
+
def forward(self, r, f):
|
39 |
+
real = denormalize_bin(r) #[-1,1] to [0,1]
|
40 |
+
fake = denormalize_bin(f)
|
41 |
+
mse_loss = F.mse_loss(real, fake)
|
42 |
+
perceptual_loss = self.perceptual_loss(real, fake)
|
43 |
+
idsim_loss = self.idsimilarity(real, fake)
|
44 |
+
ssim = structural_similarity_index_measure(fake, real)
|
45 |
+
ssim_loss = 1 - ssim
|
46 |
+
id_si = 1 - idsim_loss
|
47 |
+
|
48 |
+
total_loss = (self.mse_weight * mse_loss) + (self.perceptual_weight * perceptual_loss) + (self.idsim_weight * idsim_loss) + (self.ssim_weight * ssim_loss)
|
49 |
+
components = {
|
50 |
+
"MSE Loss": mse_loss.item(),
|
51 |
+
"Perceptual Loss": perceptual_loss.item(),
|
52 |
+
"ID-SIM Loss": idsim_loss.item(),
|
53 |
+
"SSIM Loss": ssim_loss.item()
|
54 |
+
}
|
55 |
+
|
56 |
+
return total_loss, components
|