Jannat24 committed on
Commit c9224f7 · verified · 1 Parent(s): b9119c8

2025_march16

app.py ADDED
@@ -0,0 +1,128 @@
+ import io
+ import os
+ import shutil
+ import requests
+ import numpy as np
+ from PIL import Image, ImageOps
+ import math
+ import matplotlib.pyplot as plt
+ import pickle
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ import torchvision.transforms.functional as TF
+ from torch.utils.checkpoint import checkpoint
+ from torchvision.models import vgg16
+ from torchmetrics.image.fid import FrechetInceptionDistance
+ from torchmetrics.functional import structural_similarity_index_measure
+ from facenet_pytorch import InceptionResnetV1
+ from taming.models.vqgan import VQModel
+ from omegaconf import OmegaConf
+ from taming.models.vqgan import GumbelVQ
+ import gradio as gr
+ from modules.finetunedvqgan import Generator
+ from modules.modelz import DeepfakeToSourceTransformer
+ from modules.frameworkeval import DF
+ from modules.segmentface import FaceSegmenter
+ from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ ##________________________Transformation______________________________
+
+ transform = T.Compose([
+     T.Resize((256, 256)),
+     T.ToTensor(),
+     T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])  # Normalize to [-1, 1]
+
+ #_________________Define:Gradio Function________________________
+
+ def gen_sources(deepfake_img):
+     # NOTE: `deepfake_img` is expected to be a file path (the Gradio input below uses
+     # type="filepath"), because FaceSegmenter and Image.open() both read from disk.
+     config_path = "./models/config.yaml"
+     ##----------------------Initialize:Face Segmentation----------------------------------
+     segmenter = FaceSegmenter(threshold=0.5)
+     #----------------DeepFake Face Segmentation-----------------
+     deepfake_seg = segmenter.segment_face(deepfake_img)  # BGR NumPy array
+     #------------Initialize:Decoder-F------------------------
+     checkpoint_path_f = "./models/model_vaq1_ff.pth"
+     checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
+     model_vaq_f = Generator(config_path, device).load_models()  # build the GumbelVQ model first
+     model_vaq_f.load_state_dict(checkpoint_f, strict=True)
+     model_vaq_f.eval()
+     #------------Initialize:Decoder-G------------------------
+     checkpoint_path_g = "./models/model_vaq2_gg.pth"
+     checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
+     model_vaq_g = Generator(config_path, device).load_models()
+     model_vaq_g.load_state_dict(checkpoint_g, strict=True)
+     model_vaq_g.eval()
+     ##------------------------Initialize Model-F-------------------------------------
+     model_z1 = DeepfakeToSourceTransformer().to(device)
+     model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth", map_location=device), strict=True)
+     model_z1.eval()
+     ##------------------------Initialize Model-G-------------------------------------
+     model_z2 = DeepfakeToSourceTransformer().to(device)
+     model_z2.load_state_dict(torch.load("./models/model_z2_gg.pth", map_location=device), strict=True)
+     model_z2.eval()
+     ##--------------------Initialize:Evaluation---------------------------------------
+     criterion = DF()
+
+     ##----------------------Operation-------------------------------------------------
+     with torch.no_grad():
+         # Load and preprocess input image
+         img = Image.open(deepfake_img).convert('RGB')
+         segimg = Image.fromarray(deepfake_seg[:, :, ::-1]).convert('RGB')  # BGR -> RGB
+         df_img = transform(img).unsqueeze(0).to(device)   # Shape: (1, 3, 256, 256)
+         seg_img = transform(segimg).unsqueeze(0).to(device)
+
+         # Calculate quantized_block for all images
+         z_df, _, _ = model_vaq_f.encode(df_img)
+         z_seg, _, _ = model_vaq_g.encode(seg_img)
+         rec_z_img1 = model_z1(z_df)
+         rec_z_img2 = model_z2(z_seg)
+         rec_img1 = model_vaq_f.decode(rec_z_img1)
+         rec_img2 = model_vaq_g.decode(rec_z_img2)
+         # Decoder outputs are normalized to [-1, 1]; map back to [0, 1] before ToPILImage
+         rec_img1 = denormalize_bin(rec_img1.squeeze(0))
+         rec_img2 = denormalize_bin(rec_img2.squeeze(0))
+         rec_img1_pil = T.ToPILImage()(rec_img1)
+         rec_img2_pil = T.ToPILImage()(rec_img2)
+
+         # Save PIL images to in-memory buffers
+         buffer1 = io.BytesIO()
+         buffer2 = io.BytesIO()
+         rec_img1_pil.save(buffer1, format="PNG")
+         rec_img2_pil.save(buffer2, format="PNG")
+         buffer1.seek(0)
+         buffer2.seek(0)
+
+         # Pass buffers to the Gradio client.
+         # NOTE: `client` and `file` are never defined in this file; they are assumed to come
+         # from `gradio_client` (see the note after this file), and `file()` may require the
+         # buffers to be written to temporary files rather than passed as BytesIO objects.
+         result = client.predict(
+             target=file(buffer1),
+             source=file(buffer2), slider=100, adv_slider=100,
+             settings=["Adversarial Defense"], api_name="/run_inference"
+         )
+
+         # Load result and compute loss
+         dfimage_pil = Image.open(result)  # Open the resulting image
+         buffer3 = io.BytesIO()
+         dfimage_pil.save(buffer3, format="PNG")
+         buffer3.seek(0)
+         rec_df = transform(Image.open(buffer3)).unsqueeze(0).to(device)
+         rec_loss, _ = criterion(df_img, rec_df)
+
+     return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))
+
+ #________________________Create the Gradio interface_________________________________
+ interface = gr.Interface(
+     fn=gen_sources,
+     inputs=gr.Image(type="filepath", label="Input Image"),  # filepath: gen_sources reads from disk
+     outputs=[
+         gr.Image(type="pil", label="Recovered Source Image 1 (Target Image)"),
+         gr.Image(type="pil", label="Recovered Source Image 2 (Source Image)"),
+         gr.Image(type="pil", label="Reconstructed Deepfake Image"),
+         gr.Number(label="Reconstruction Loss")
+     ],
+     examples=["./images/df1.jpg", "./images/df2.jpg", "./images/df3.jpg", "./images/df4.jpg"],
+     theme=gr.themes.Soft(),
+     title="Uncovering Deepfake Images to Identify Source Images",
+     description="Upload a deepfake image.",
+ )
+
+ interface.launch(debug=True)
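
Note: gen_sources relies on `client` and `file`, which app.py never defines. A minimal sketch of that setup, assuming the face-swap step is served by a separate Gradio Space reached through `gradio_client` (the Space id below is a placeholder, not taken from this commit):

    import tempfile
    from gradio_client import Client, file

    client = Client("user/face-swap-space")   # placeholder Space id; the real one is not in this commit

    def to_tmp_png(pil_img):
        # file() takes a path or URL, so write each in-memory PNG to a temporary file first
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        pil_img.save(tmp.name, format="PNG")
        return tmp.name

    # e.g. inside gen_sources:
    # result = client.predict(target=file(to_tmp_png(rec_img1_pil)),
    #                         source=file(to_tmp_png(rec_img2_pil)),
    #                         slider=100, adv_slider=100,
    #                         settings=["Adversarial Defense"], api_name="/run_inference")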
images/df1.jpg ADDED
images/df2.jpg ADDED
images/df3.jpg ADDED
images/df4.jpg ADDED
modules/denormalize.py ADDED
@@ -0,0 +1,21 @@
+ import numpy as np
+ import torch
+ import math
+ #------------------Denormalization---------------------------------------------
+ def denormalize_bin(tensor):
+     tr = torch.clamp(tensor, -1., 1.)  # Clamp the values between -1 and 1
+     tr = tr.add(1).div(2)              # Shift to [0, 1]
+     return tr
+
+ def denormalize_tr(tensor):
+     tr = torch.clamp(tensor, -1., 1.)  # Clamp the values between -1 and 1
+     tr = tr.add(1).div(2).mul(255)     # Shift to [0, 1] and scale to [0, 255]
+     tr = tr.byte()                     # Convert the tensor to uint8
+     return tr
+
+ def denormalize_ar(tensor):
+     tr = torch.clamp(tensor, -1., 1.)  # Clamp the values between -1 and 1
+     tr = tr.add(1).div(2).mul(255)     # Shift to [0, 1] and scale to [0, 255]
+     tr = tr.byte()                     # Convert the tensor to uint8
+     arr = tr.permute(0, 2, 3, 1).cpu().detach().numpy()  # Convert to (N, H, W, C) NumPy array
+     return arr
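
A minimal usage sketch, assuming a stand-in decoder output tensor in [-1, 1]:

    import torch
    import torchvision.transforms as T
    from modules.denormalize import denormalize_bin, denormalize_ar

    x = torch.rand(1, 3, 256, 256) * 2 - 1        # stand-in for a decoder output in [-1, 1]
    pil = T.ToPILImage()(denormalize_bin(x)[0])   # [0, 1] float tensor -> PIL image
    arr = denormalize_ar(x)                       # (1, 256, 256, 3) uint8 NumPy array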
modules/finetunedvqgan.py ADDED
@@ -0,0 +1,31 @@
+ import torch
+ from torch.utils.checkpoint import checkpoint
+ from taming.models.vqgan import VQModel
+ from omegaconf import OmegaConf
+ from taming.models.vqgan import GumbelVQ
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class Generator:
+     def __init__(self, config_path, device=device):
+         self.config_path = config_path
+         self.device = device
+
+     def load_models(self):
+         # Load configuration
+         config = OmegaConf.load(self.config_path)
+         # Extract parameters specific to GumbelVQ
+         vq_params = config.model.params
+         # Initialize the GumbelVQ model (weights are untrained here; the caller loads
+         # the fine-tuned state dict afterwards, as app.py does)
+         model_vaq = GumbelVQ(
+             ddconfig=vq_params.ddconfig,
+             lossconfig=vq_params.lossconfig,
+             n_embed=vq_params.n_embed,
+             embed_dim=vq_params.embed_dim,
+             kl_weight=vq_params.kl_weight,
+             temperature_scheduler_config=vq_params.temperature_scheduler_config,
+         ).to(self.device)
+
+         return model_vaq
@@ -0,0 +1,56 @@
 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision.models import vgg16
+ from torchmetrics.functional import structural_similarity_index_measure
+ from facenet_pytorch import InceptionResnetV1
+ # Absolute import so this module works when imported as modules.frameworkeval from app.py
+ from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar
+
+ class DF(nn.Module):
+     def __init__(self):
+         super(DF, self).__init__()
+         self.mse_weight = 0.25
+         self.perceptual_weight = 0.25
+         self.ssim_weight = 0.25
+         self.idsim_weight = 0.25
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.vgg = vgg16(pretrained=True).features[:16].to(self.device).eval()
+         self.facenet = InceptionResnetV1(pretrained='vggface2').to(self.device).eval()
+         for param in self.facenet.parameters():
+             param.requires_grad = False  # Freeze the model
+         self.cosloss = nn.CosineEmbeddingLoss()
+
+     def perceptual_loss(self, real, fake):
+         with torch.no_grad():  # VGG is frozen during training
+             real_features = self.vgg(real)
+             fake_features = self.vgg(fake)
+             return F.mse_loss(real_features, fake_features)
+
+     def idsimilarity(self, real, fake):
+         with torch.no_grad():
+             # Extract embeddings
+             input_embed = self.facenet(real).to(self.device)
+             generated_embed = self.facenet(fake).to(self.device)
+             # Compute cosine similarity loss
+             target = torch.ones(input_embed.size(0)).to(real.device)  # Target = 1 (maximize similarity)
+             return self.cosloss(input_embed, generated_embed, target)
+
+     def forward(self, r, f):
+         real = denormalize_bin(r)  # [-1, 1] to [0, 1]
+         fake = denormalize_bin(f)
+         mse_loss = F.mse_loss(real, fake)
+         perceptual_loss = self.perceptual_loss(real, fake)
+         idsim_loss = self.idsimilarity(real, fake)
+         ssim = structural_similarity_index_measure(fake, real)
+         ssim_loss = 1 - ssim
+
+         total_loss = (self.mse_weight * mse_loss) + (self.perceptual_weight * perceptual_loss) + (self.idsim_weight * idsim_loss) + (self.ssim_weight * ssim_loss)
+         components = {
+             "MSE Loss": mse_loss.item(),
+             "Perceptual Loss": perceptual_loss.item(),
+             "ID-SIM Loss": idsim_loss.item(),
+             "SSIM Loss": ssim_loss.item()
+         }
+
+         return total_loss, components
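
A minimal sanity-check sketch of how DF is invoked (as in app.py), assuming random stand-in tensors in [-1, 1] rather than real decoder outputs:

    import torch
    from modules.frameworkeval import DF

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = DF()
    real = torch.rand(1, 3, 256, 256, device=device) * 2 - 1   # stand-in tensors in [-1, 1]
    fake = torch.rand(1, 3, 256, 256, device=device) * 2 - 1
    total, parts = criterion(real, fake)
    print(round(total.item(), 3), parts)                        # weighted sum and the four components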
modules/modelz.py ADDED
@@ -0,0 +1,155 @@
+ import numpy as np
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ ##_____________________Define:MODEL-F & MODEL-G_________________
+
+ # Positional Encoding
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=1024):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(0.1)
+         position = torch.arange(max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe = torch.zeros(max_len, d_model)
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)
+
+     def forward(self, x):
+         x = x + self.pe[:, :x.size(1)]
+         return self.dropout(x)
+
+ # Transformer Encoder
+ class TransformerEncoder(nn.Module):
+     def __init__(self, d_model=256, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.1):
+         super(TransformerEncoder, self).__init__()
+         self.positional_encoding = PositionalEncoding(d_model)
+         self.encoder = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True),
+             num_layers=num_layers
+         )
+
+     def preprocess_latent(self, Z):
+         batch_size, channels, height, width = Z.shape  # (batch_size, 256, 32, 32)
+         seq_len = height * width
+         Z = Z.permute(0, 2, 3, 1).reshape(batch_size, seq_len, channels)  # (batch_size, 1024, 256)
+         return Z
+
+     def postprocess_latent(self, Z):
+         batch_size, seq_len, channels = Z.shape  # (batch_size, 1024, 256)
+         height = width = int(math.sqrt(seq_len))
+         Z = Z.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)  # (batch_size, 256, 32, 32)
+         return Z
+
+     def forward(self, Z):
+         Z = self.preprocess_latent(Z)
+         Z = self.positional_encoding(Z)
+         Z = self.encoder(Z)
+         Z = self.postprocess_latent(Z)
+         return Z  # latent of transformer
+
+ class TransformerDecoder(nn.Module):
+     def __init__(self, d_model=256, nhead=8, num_layers=12, dim_feedforward=1024, dropout=0.1):
+         super().__init__()
+         self.d_model = d_model
+
+         # Enhanced positional encoding
+         self.positional_encoding = PositionalEncoding(d_model)
+
+         # Multi-layer learnable start tokens
+         self.base_start = nn.Parameter(torch.randn(1, 1024, d_model))
+         self.start_net = nn.Sequential(
+             nn.LayerNorm(d_model),
+             nn.Linear(d_model, dim_feedforward),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(dim_feedforward, d_model),
+             nn.LayerNorm(d_model)
+         )
+
+         # Context-aware transformer decoder
+         self.decoder = nn.TransformerDecoder(
+             nn.TransformerDecoderLayer(
+                 d_model=d_model,
+                 nhead=nhead,
+                 dim_feedforward=dim_feedforward,
+                 dropout=dropout,
+                 batch_first=True
+             ),
+             num_layers=num_layers
+         )
+
+         # Output projection with residual
+         self.output_layer = nn.Sequential(
+             nn.Linear(d_model, d_model*2),
+             nn.GELU(),
+             nn.Linear(d_model*2, d_model))
+
+         self.init_weights()
+
+     def init_weights(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+         nn.init.normal_(self.base_start, mean=0, std=0.02)
+
+     def preprocess_latent(self, Z):
+         # Convert (B, C, H, W) to (B, H*W, C)
+         return Z.permute(0, 2, 3, 1).flatten(1, 2)
+
+     def postprocess_latent(self, Z):
+         # Convert (B, H*W, C) back to (B, C, H, W)
+         B, L, C = Z.shape
+         H = W = int(L**0.5)
+         return Z.view(B, H, W, C).permute(0, 3, 1, 2)
+
+     def forward(self, Z, Z1_start_tokens=None, teacher_forcing_ratio=0.5):
+         # Process input latent
+         Z = self.preprocess_latent(Z)
+         #Z = self.positional_encoding(Z)
+
+         # Generate enhanced start tokens
+         B = Z.size(0)
+         base_tokens = self.base_start.expand(B, -1, -1)
+         processed_start = self.start_net(base_tokens)
+
+         # Teacher forcing integration
+         if Z1_start_tokens is not None and teacher_forcing_ratio > 0:
+             Z1_processed = self.positional_encoding(self.preprocess_latent(Z1_start_tokens))
+
+             # Create mixing mask
+             mask = torch.rand(B, 1, 1, device=Z.device) < teacher_forcing_ratio
+             processed_start = torch.where(mask, Z1_processed, processed_start)
+
+         # Decoder processing with residual
+         decoder_input = self.positional_encoding(processed_start)
+         outputs = self.decoder(decoder_input, Z)
+         outputs = self.output_layer(outputs + decoder_input)
+
+         return self.postprocess_latent(outputs)
+
+ class DeepfakeToSourceTransformer(nn.Module):
+     def __init__(self, d_model=256, encoder_nhead=8, decoder_nhead=8, num_encoder_layers=6, num_decoder_layers=12, dim_feedforward=1024, dropout=0.1):
+         super().__init__()
+         self.encoder = TransformerEncoder(
+             d_model=d_model,
+             nhead=encoder_nhead,
+             num_layers=num_encoder_layers,
+             dim_feedforward=dim_feedforward,
+             dropout=dropout
+         )
+         self.decoder = TransformerDecoder(
+             d_model=d_model,
+             nhead=decoder_nhead,
+             num_layers=num_decoder_layers,
+             dim_feedforward=dim_feedforward,
+             dropout=dropout
+         )
+
+     def forward(self, Z, Z1_start_tokens=None, teacher_forcing_ratio=0.5):
+         memory = self.encoder(Z)
+         Z1 = self.decoder(memory, Z1_start_tokens, teacher_forcing_ratio=teacher_forcing_ratio)
+         return Z1
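
A minimal shape-contract sketch, assuming a random stand-in for the (B, 256, 32, 32) VQGAN latent that app.py passes in; the model returns a latent of the same shape:

    import torch
    from modules.modelz import DeepfakeToSourceTransformer

    model = DeepfakeToSourceTransformer().eval()
    z = torch.randn(1, 256, 32, 32)      # VQGAN latent: (B, C=256, H=32, W=32)
    with torch.no_grad():
        z1 = model(z)                     # mapped latent, same shape
    print(z1.shape)                       # torch.Size([1, 256, 32, 32])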
modules/segmentface.py ADDED
@@ -0,0 +1,75 @@
+ import io
+ import cv2
+ import mediapipe as mp
+ import numpy as np
+ from rembg import remove
+ from PIL import Image
+
+ class FaceSegmenter:
+     def __init__(self, threshold=0.5):
+         self.threshold = threshold
+         # Initialize face detection
+         self.face_detection = mp.solutions.face_detection.FaceDetection(
+             model_selection=1,  # 1 for general use, 0 for close-up faces
+             min_detection_confidence=0.5
+         )
+         # Initialize selfie segmentation (for background removal)
+         self.selfie_segmentation = mp.solutions.selfie_segmentation.SelfieSegmentation(
+             model_selection=1  # 1 for general use, 0 for close-up faces
+         )
+
+     def segment_face(self, image_path):
+         # Load the image
+         image = cv2.imread(image_path)
+         if image is None:
+             raise ValueError("Image not found or unable to load.")
+
+         # Convert to RGB (MediaPipe requires RGB input)
+         rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         # Step 1: Detect the face
+         face_results = self.face_detection.process(rgb_image)
+         if not face_results.detections:
+             # No face found: fall back to rembg background removal
+             with open(image_path, "rb") as input_file:
+                 input_image = input_file.read()
+             output_image = remove(input_image)
+             # Convert the output bytes to a NumPy array
+             output_image = np.array(Image.open(io.BytesIO(output_image)))
+             # Convert to BGR so both return paths use OpenCV's channel order
+             if output_image.shape[2] == 4:
+                 output_image = cv2.cvtColor(output_image, cv2.COLOR_RGBA2BGR)
+             else:
+                 output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
+             return output_image
+
+         # Get the bounding box of the first detected face
+         detection = face_results.detections[0]
+         bboxC = detection.location_data.relative_bounding_box
+         h, w, _ = image.shape
+         x, y, width, height = int(bboxC.xmin * w), int(bboxC.ymin * h), \
+                               int(bboxC.width * w), int(bboxC.height * h)
+
+         # Step 2: Segment the foreground (selfie segmentation)
+         segmentation_results = self.selfie_segmentation.process(rgb_image)
+         if segmentation_results.segmentation_mask is None:
+             raise ValueError("Segmentation failed.")
+
+         # Create a binary mask
+         mask = (segmentation_results.segmentation_mask > self.threshold).astype(np.uint8)
+
+         # Step 3: Crop the face using the bounding box
+         face_mask = np.zeros_like(mask)
+         face_mask[y:y+height, x:x+width] = mask[y:y+height, x:x+width]
+
+         # Apply the mask to the original image
+         segmented_face = cv2.bitwise_and(image, image, mask=face_mask)
+
+         return segmented_face  # BGR NumPy array
+
+     def save_segmented_face(self, image_path, output_path):
+         segmented_face = self.segment_face(image_path)
+         cv2.imwrite(output_path, segmented_face)
+
+     def show_segmented_face(self, image_path):
+         segmented_face = self.segment_face(image_path)
+         cv2.imshow("Segmented Face", segmented_face)
+         cv2.waitKey(0)
+         cv2.destroyAllWindows()
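
A minimal usage sketch, assuming one of the example images added in this commit; the output filename is illustrative:

    from modules.segmentface import FaceSegmenter

    segmenter = FaceSegmenter(threshold=0.5)
    face_bgr = segmenter.segment_face("./images/df1.jpg")               # BGR NumPy array
    segmenter.save_segmented_face("./images/df1.jpg", "./df1_face.png")  # write the masked face to disk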