Jannat24 committed
Commit da52bc4 · 1 Parent(s): fb1bd05

16march2025
Files changed (1)
  1. app.py +91 -179
app.py CHANGED
@@ -2,215 +2,127 @@ import io
  import os
  import shutil
  import requests
- import time
  import numpy as np
  from PIL import Image, ImageOps
- from math import nan
  import math
  import pickle
- import warnings
- warnings.filterwarnings("ignore")
-
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- import torch.optim as optim
-
- from torch.utils.data import Dataset, ConcatDataset, DataLoader
- from torchvision.datasets import ImageFolder
  import torchvision.transforms as T
  import torchvision.transforms.functional as TF
- from torch.cuda.amp import autocast, GradScaler
- import jax
- import jax.numpy as jnp
-
- import transformers
- from transformers.modeling_flax_utils import FlaxPreTrainedModel
- from vqgan_jax.modeling_flax_vqgan import VQModel
-
  import gradio as gr

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- class Model_Z1(nn.Module):
-     def __init__(self):
-         super(Model_Z1, self).__init__()
-         self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
-         self.batchnorm = nn.BatchNorm2d(2048)
-         self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
-         self.batchnorm2 = nn.BatchNorm2d(256)
-         self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
-         self.batchnorm3 = nn.BatchNorm2d(1024)
-         self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
-         self.batchnorm4 = nn.BatchNorm2d(256)
-         self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
-         self.batchnorm5 = nn.BatchNorm2d(512)
-         self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
-         self.elu = nn.ELU()
-
-     def forward(self, x):
-         res = x
-         x = self.elu(self.conv1(x))
-         x = self.batchnorm(x)
-         x = self.elu(self.conv2(x)) + res
-         x = self.batchnorm2(x)
-         x = self.elu(self.conv3(x))
-         x = self.batchnorm3(x)
-         x = self.elu(self.conv4(x)) + res
-         x = self.batchnorm4(x)
-         x = self.elu(self.conv5(x))
-         x = self.batchnorm5(x)
-         out = self.elu(self.conv6(x)) + res
-         return out
-
- class Model_Z(nn.Module):
-     def __init__(self):
-         super(Model_Z, self).__init__()
-         self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
-         self.batchnorm = nn.BatchNorm2d(2048)
-         self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
-         self.batchnorm2 = nn.BatchNorm2d(256)
-         self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
-         self.batchnorm3 = nn.BatchNorm2d(1024)
-         self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
-         self.batchnorm4 = nn.BatchNorm2d(256)
-         self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
-         self.batchnorm5 = nn.BatchNorm2d(512)
-         self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
-         self.batchnorm6 = nn.BatchNorm2d(256)
-         self.conv7 = nn.Conv2d(in_channels=256, out_channels=448, kernel_size=3, padding=1)
-         self.batchnorm7 = nn.BatchNorm2d(448)
-         self.conv8 = nn.Conv2d(in_channels=448, out_channels=384, kernel_size=3, padding=1)
-         self.batchnorm8 = nn.BatchNorm2d(384)
-         self.conv9 = nn.Conv2d(in_channels=384, out_channels=320, kernel_size=3, padding=1)
-         self.batchnorm9 = nn.BatchNorm2d(320)
-         self.conv10 = nn.Conv2d(in_channels=320, out_channels=256, kernel_size=3, padding=1)
-         self.elu = nn.ELU()
-
-     def forward(self, x):
-         res = x
-         x = self.elu(self.conv1(x))
-         x = self.batchnorm(x)
-         x = self.elu(self.conv2(x)) + res
-         x = self.batchnorm2(x)
-         x = self.elu(self.conv3(x))
-         x = self.batchnorm3(x)
-         x = self.elu(self.conv4(x)) + res
-         x = self.batchnorm4(x)
-         x = self.elu(self.conv5(x))
-         x = self.batchnorm5(x)
-         x = self.elu(self.conv6(x)) + res
-         x = self.batchnorm6(x)
-         x = self.elu(self.conv7(x))
-         x = self.batchnorm7(x)
-         x = self.elu(self.conv8(x))
-         x = self.batchnorm8(x)
-         x = self.elu(self.conv9(x))
-         x = self.batchnorm9(x)
-         out = self.elu(self.conv10(x)) + res
-         return out
-
-
- def tensor_jax(x):
-     if x.dim() == 3:
-         x = x.unsqueeze(0)
-
-     x_np = x.detach().permute(0, 2, 3, 1).cpu().numpy()  # convert from (N, C, H, W) to (N, H, W, C) and move to CPU
-     x_jax = jnp.array(x_np)
-     return x_jax
-
- def jax_to_tensor(x):
-     x_tensor = torch.tensor(np.array(x), requires_grad=True).permute(0, 3, 1, 2).to(device)  # convert from (N, H, W, C) to (N, C, H, W)
-     return x_tensor
-
- # Define the transform
  transform = T.Compose([
-     T.Resize((256, 256)),
-     T.ToTensor()
- ])
-
- def gen_sources(img):
-     model_name = "dalle-mini/vqgan_imagenet_f16_16384"
-     model_vaq = VQModel.from_pretrained(model_name)
-
-     model_z1 = Model_Z1()
-     model_z1 = model_z1.to(device)
-     model_z1.load_state_dict(torch.load("./model_z1.pth", map_location=device))
-
-     model_z2 = Model_Z()
-     model_z2 = model_z2.to(device)
-     model_z2.load_state_dict(torch.load("./model_z2.pth", map_location=device))
-
-     model_zdf = Model_Z()
-     model_zdf = model_zdf.to(device)
-     model_zdf.load_state_dict(torch.load("./model_zdf.pth", map_location=device))
-
-     criterion = nn.MSELoss()
      model_z1.eval()
      model_z2.eval()
-     model_zdf.eval()
-
      with torch.no_grad():
-         img = img.convert('RGB')
-         df_img = transform(img)
-         df_img = df_img.unsqueeze(0)  # change shape to (1, 3, 256, 256)
-         df_img = df_img.to(device)
-         # convert image: tensor --> jax_array
-         df_img_jax = tensor_jax(df_img)
-         # calculate quantized_code(z) for all images
-         z_df, _ = model_vaq.encode(df_img_jax)
-         # convert quantized_code(z): jax_array --> tensor
-         z_df_tensor = jax_to_tensor(z_df)
-         ##----------------------model_z1-----------------------
-         outputs_z1 = model_z1(z_df_tensor)
-         # generate img1
-         z1_rec_jax = tensor_jax(outputs_z1)
-         rec_img1 = model_vaq.decode(z1_rec_jax)
-         ##----------------------model_z2-----------------------
-         outputs_z2 = model_z2(z_df_tensor)
-         # generate img2
-         z2_rec_jax = tensor_jax(outputs_z2)
-         rec_img2 = model_vaq.decode(z2_rec_jax)
-         ##----------------------model_zdf-----------------------
-         z_rec = outputs_z1 + outputs_z2
-         outputs_zdf = model_zdf(z_rec)
-         lossdf = criterion(outputs_zdf, z_df_tensor)
-         # calculate deepfake-image reconstruction loss
-         zdf_rec_jax = tensor_jax(outputs_zdf)
-         rec_df = model_vaq.decode(zdf_rec_jax)
-         rec_df_tensor = jax_to_tensor(rec_df)
-         dfimgloss = criterion(rec_df_tensor, df_img)
-         # convert tensors back to PIL images
-         rec_img1 = jax_to_tensor(rec_img1)
          rec_img1 = rec_img1.squeeze(0)
-         rec_img2 = jax_to_tensor(rec_img2)
          rec_img2 = rec_img2.squeeze(0)
-         rec_df = jax_to_tensor(rec_df)
-         rec_df = rec_df.squeeze(0)
          rec_img1_pil = T.ToPILImage()(rec_img1)
          rec_img2_pil = T.ToPILImage()(rec_img2)
-         rec_df_pil = T.ToPILImage()(rec_df)
-
-     return (rec_img1_pil, rec_img2_pil, round(dfimgloss.item(), 3))

- # Create the Gradio interface
  interface = gr.Interface(
      fn=gen_sources,
      inputs=gr.Image(type="pil", label="Input Image"),
      outputs=[
-         gr.Image(type="pil", label="Source Image 1"),
-         gr.Image(type="pil", label="Source Image 2"),
-         # gr.Image(type="pil", label="Deepfake Image"),
          gr.Number(label="Reconstruction Loss")
      ],
      examples=["./df1.jpg", "./df2.jpg", "./df3.jpg", "./df4.jpg"],
      theme=gr.themes.Soft(),
-     title="Uncovering Deepfake Image",
-     description="Upload an image.",
  )

- interface.launch()

  import os
  import shutil
  import requests
  import numpy as np
  from PIL import Image, ImageOps
  import math
+ import matplotlib.pyplot as plt
  import pickle
+ from io import BytesIO  # needed for the in-memory PNG buffers used below
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  import torchvision.transforms as T
  import torchvision.transforms.functional as TF
+ from torch.utils.checkpoint import checkpoint
+ from torchvision.models import vgg16
+ from torchmetrics.image.fid import FrechetInceptionDistance
+ from torchmetrics.functional import structural_similarity_index_measure
+ from facenet_pytorch import InceptionResnetV1
+ from taming.models.vqgan import VQModel, GumbelVQ
+ from omegaconf import OmegaConf
  import gradio as gr
+ from gradio_client import Client, file  # `client.predict(...)` and `file(...)` are used in gen_sources
+ from finetunedvqgan import Generator
+ from modelz import DeepfakeToSourceTransformer
+ from frameworkeval import DF
+ from segmentface import FaceSegmenter
+ from denormalize import denormalize_bin, denormalize_tr, denormalize_ar

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
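
Note: the VQModel, GumbelVQ, and OmegaConf imports are consumed inside finetunedvqgan.Generator, whose source is not part of this commit. A plausible sketch of such a wrapper, assuming it builds a taming VQGAN from the YAML config and delegates encode/decode (hypothetical, not the actual module):

class Generator(nn.Module):
    # Hypothetical stand-in for finetunedvqgan.Generator: parse the config
    # with OmegaConf, build a taming VQModel, and delegate encode/decode.
    def __init__(self, config_path):
        super().__init__()
        config = OmegaConf.load(config_path)
        self.model = VQModel(**config.model.params)

    def encode(self, x):
        # taming's VQModel.encode returns (quant, emb_loss, info),
        # matching the three-way unpacking in gen_sources below
        return self.model.encode(x)

    def decode(self, z):
        return self.model.decode(z)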

+ ##________________________Transformation______________________________

  transform = T.Compose([
+     T.Resize((256, 256)),
+     T.ToTensor(),
+     T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize to [-1, 1]
+ ])
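
Note on the normalization: after T.Normalize the tensors live in [-1, 1], while T.ToPILImage() used below expects values in [0, 1]; the imported denormalize_* helpers presumably undo this. A minimal sketch of that inverse mapping, assuming the helpers simply invert the Normalize above (their actual signatures in denormalize.py are not shown in this commit):

def denormalize(x):
    # Hypothetical inverse of T.Normalize(mean=0.5, std=0.5): map a
    # tensor from [-1, 1] back to [0, 1] before T.ToPILImage().
    return (x.clamp(-1.0, 1.0) + 1.0) / 2.0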
+
+ #_________________Define:Gradio Function________________________
+
+ def gen_sources(deepfake_img):
+     config_path = "./config.yaml"
+     ##----------------------Initialize:Face Segmentation----------------------------------
+     segmenter = FaceSegmenter(threshold=0.5)  # must exist before segmenting the input
+     #----------------DeepFake Face Segmentation-----------------
+     deepfake_seg = segmenter.segment_face(deepfake_img)
+     #------------Initialize:Decoder-F------------------------
+     checkpoint_path_f = "./model_vaq1_ff.pth"
+     checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
+     model_vaq_f = Generator(config_path).to(device)
+     # load_state_dict returns a key report, not the model, so don't reassign
+     model_vaq_f.load_state_dict(checkpoint_f, strict=True)
+     model_vaq_f.eval()
+     #------------Initialize:Decoder-G------------------------
+     checkpoint_path_g = "./model_vaq2_gg.pth"
+     checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
+     model_vaq_g = Generator(config_path).to(device)
+     model_vaq_g.load_state_dict(checkpoint_g, strict=True)
+     model_vaq_g.eval()
+     ##------------------------Initialize Model-F-------------------------------------
+     model_z1 = DeepfakeToSourceTransformer().to(device)
+     model_z1.load_state_dict(torch.load("./model_z1_ff.pth", map_location=device), strict=True)
      model_z1.eval()
+     ##------------------------Initialize Model-G-------------------------------------
+     model_z2 = DeepfakeToSourceTransformer().to(device)
+     model_z2.load_state_dict(torch.load("./model_z2_gg.pth", map_location=device), strict=True)
      model_z2.eval()
+     ##--------------------Initialize:Evaluation---------------------------------------
+     criterion = DF()
+
+     ##----------------------Operation-------------------------------------------------
      with torch.no_grad():
+         # Preprocess the input image (gr.Image(type="pil") already yields a PIL image)
+         img = deepfake_img.convert('RGB')
+         segimg = Image.open(deepfake_seg).convert('RGB')
+         df_img = transform(img).unsqueeze(0).to(device)  # shape: (1, 3, 256, 256)
+         seg_img = transform(segimg).unsqueeze(0).to(device)
+
+         # Calculate the quantized codes for both images
+         z_df, _, _ = model_vaq_f.encode(df_img)
+         z_seg, _, _ = model_vaq_g.encode(seg_img)
+         rec_z_img1 = model_z1(z_df)
+         rec_z_img2 = model_z2(z_seg)
+         rec_img1 = model_vaq_f.decode(rec_z_img1)
+         rec_img2 = model_vaq_g.decode(rec_z_img2)
          rec_img1 = rec_img1.squeeze(0)
          rec_img2 = rec_img2.squeeze(0)
          rec_img1_pil = T.ToPILImage()(rec_img1)
          rec_img2_pil = T.ToPILImage()(rec_img2)

+         # Save PIL images to in-memory buffers
+         buffer1 = BytesIO()
+         buffer2 = BytesIO()
+         rec_img1_pil.save(buffer1, format="PNG")
+         rec_img2_pil.save(buffer2, format="PNG")
+         buffer1.seek(0)  # rewind before the buffers are read
+         buffer2.seek(0)
+
+         # Pass buffers to the Gradio client (see the note on `client` below)
+         result = client.predict(
+             target=file(buffer1),
+             source=file(buffer2),
+             slider=100,
+             adv_slider=100,
+             settings=["Adversarial Defense"],
+             api_name="/run_inference"
+         )
+
+         # Load the result and compute the reconstruction loss
+         dfimage_pil = Image.open(result)  # client.predict returns a path to the output image
+         buffer3 = BytesIO()
+         dfimage_pil.save(buffer3, format="PNG")
+         buffer3.seek(0)
+         rec_df = transform(Image.open(buffer3)).unsqueeze(0).to(device)
+         rec_loss, _ = criterion(df_img, rec_df)
+
+     return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))
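
The `client` called inside gen_sources is never constructed in this file; it is presumably a gradio_client.Client for the face-swap Space that exposes the /run_inference endpoint. The Space id is not given in the source, so the setup below is a placeholder sketch:

# Hypothetical setup; substitute the real face-swap Space id.
client = Client("<owner>/<face-swap-space>")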
+
+ #________________________Create the Gradio interface_________________________________
  interface = gr.Interface(
      fn=gen_sources,
      inputs=gr.Image(type="pil", label="Input Image"),
      outputs=[
+         gr.Image(type="pil", label="Recovered Source Image 1 (Target Image)"),
+         gr.Image(type="pil", label="Recovered Source Image 2 (Source Image)"),
+         gr.Image(type="pil", label="Reconstructed Deepfake Image"),
          gr.Number(label="Reconstruction Loss")
      ],
      examples=["./df1.jpg", "./df2.jpg", "./df3.jpg", "./df4.jpg"],
      theme=gr.themes.Soft(),
+     title="Uncovering Deepfake Images to Identify Source Images",
+     description="Upload a deepfake image.",
  )

+ interface.launch()
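
One design note: gen_sources reloads every checkpoint on each request, so every prediction pays the full weight-loading cost. A common Gradio pattern is to load the models once at module scope and let the handler reuse them; a minimal sketch under that assumption, shown for model_z1 (the other models would follow the same pattern):

# Sketch: load heavy weights once at startup and reuse them in gen_sources.
model_z1 = DeepfakeToSourceTransformer().to(device)
model_z1.load_state_dict(torch.load("./model_z1_ff.pth", map_location=device), strict=True)
model_z1.eval()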