Jannat24 committed
Commit cae99db
1 Parent(s): 973cb90
Files changed (1):
  1. app.py +217 -0
app.py ADDED
@@ -0,0 +1,217 @@
import warnings
warnings.filterwarnings("ignore")

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torchvision.transforms as T

import jax.numpy as jnp

from vqgan_jax.modeling_flax_vqgan import VQModel

import gradio as gr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Model_Z1(nn.Module):
    """Residual CNN that refines a 256-channel VQGAN latent into the latent of source image 1."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
        self.batchnorm = nn.BatchNorm2d(2048)
        self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(1024)
        self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self.elu = nn.ELU()

    def forward(self, x):
        res = x  # input latent, re-added after every channel-restoring conv
        x = self.elu(self.conv1(x))
        x = self.batchnorm(x)
        x = self.elu(self.conv2(x)) + res
        x = self.batchnorm2(x)
        x = self.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = self.elu(self.conv4(x)) + res
        x = self.batchnorm4(x)
        x = self.elu(self.conv5(x))
        x = self.batchnorm5(x)
        out = self.elu(self.conv6(x)) + res
        return out
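
# Minimal shape sanity check (a sketch; assumes the f16 VQGAN maps a 256x256
# image to a 16x16 grid of 256-channel latents, so z has shape (1, 256, 16, 16)):
#   z = torch.randn(1, 256, 16, 16)
#   assert Model_Z1()(z).shape == z.shape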


class Model_Z(nn.Module):
    """Deeper residual CNN; used both for the second source latent and for recombining latents."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
        self.batchnorm = nn.BatchNorm2d(2048)
        self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(1024)
        self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm6 = nn.BatchNorm2d(256)
        self.conv7 = nn.Conv2d(in_channels=256, out_channels=448, kernel_size=3, padding=1)
        self.batchnorm7 = nn.BatchNorm2d(448)
        self.conv8 = nn.Conv2d(in_channels=448, out_channels=384, kernel_size=3, padding=1)
        self.batchnorm8 = nn.BatchNorm2d(384)
        self.conv9 = nn.Conv2d(in_channels=384, out_channels=320, kernel_size=3, padding=1)
        self.batchnorm9 = nn.BatchNorm2d(320)
        self.conv10 = nn.Conv2d(in_channels=320, out_channels=256, kernel_size=3, padding=1)
        self.elu = nn.ELU()

    def forward(self, x):
        res = x  # input latent, re-added after each 256-channel stage
        x = self.elu(self.conv1(x))
        x = self.batchnorm(x)
        x = self.elu(self.conv2(x)) + res
        x = self.batchnorm2(x)
        x = self.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = self.elu(self.conv4(x)) + res
        x = self.batchnorm4(x)
        x = self.elu(self.conv5(x))
        x = self.batchnorm5(x)
        x = self.elu(self.conv6(x)) + res
        x = self.batchnorm6(x)
        x = self.elu(self.conv7(x))
        x = self.batchnorm7(x)
        x = self.elu(self.conv8(x))
        x = self.batchnorm8(x)
        x = self.elu(self.conv9(x))
        x = self.batchnorm9(x)
        out = self.elu(self.conv10(x)) + res
        return out


def tensor_jax(x):
    """Convert a torch tensor (C, H, W) or (N, C, H, W) to a JAX array (N, H, W, C)."""
    if x.dim() == 3:
        x = x.unsqueeze(0)
    # Channels-last layout and host memory before handing off to JAX.
    x_np = x.detach().permute(0, 2, 3, 1).cpu().numpy()
    return jnp.array(x_np)


def jax_to_tensor(x):
    """Convert a JAX array (N, H, W, C) back to a torch tensor (N, C, H, W) on `device`."""
    return torch.tensor(np.array(x)).permute(0, 3, 1, 2).to(device)
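
# Round-trip sketch (assumes the two converters above are inverses up to
# dtype and device):
#   x = torch.rand(1, 3, 256, 256, device=device)
#   assert torch.allclose(jax_to_tensor(tensor_jax(x)).cpu(), x.cpu(), atol=1e-6)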


# Preprocessing: resize inputs to the 256x256 resolution the VQGAN expects.
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor()
])


def gen_sources(img):
    """Split a suspected deepfake image into its two reconstructed source images.

    Returns the two sources as PIL images plus the pixel-space reconstruction
    loss of the deepfake itself.
    """
    # Frozen pretrained f16 VQGAN used as encoder/decoder.
    model_name = "dalle-mini/vqgan_imagenet_f16_16384"
    model_vaq = VQModel.from_pretrained(model_name)

    model_z1 = Model_Z1().to(device)
    model_z1.load_state_dict(torch.load("model_z1.pth", map_location=device))

    model_z2 = Model_Z().to(device)
    model_z2.load_state_dict(torch.load("model_z2.pth", map_location=device))

    model_zdf = Model_Z().to(device)
    model_zdf.load_state_dict(torch.load("model_zdf.pth", map_location=device))

    criterion = nn.MSELoss()
    model_z1.eval()
    model_z2.eval()
    model_zdf.eval()

    with torch.no_grad():
        img = img.convert("RGB")
        df_img = transform(img).unsqueeze(0).to(device)  # (1, 3, 256, 256)

        # Encode the deepfake: tensor -> JAX array -> quantized latent code z.
        df_img_jax = tensor_jax(df_img)
        z_df, _ = model_vaq.encode(df_img_jax)
        z_df_tensor = jax_to_tensor(z_df)

        # ---------------------- model_z1: recover source 1 ----------------------
        outputs_z1 = model_z1(z_df_tensor)
        rec_img1 = model_vaq.decode(tensor_jax(outputs_z1))

        # ---------------------- model_z2: recover source 2 ----------------------
        outputs_z2 = model_z2(z_df_tensor)
        rec_img2 = model_vaq.decode(tensor_jax(outputs_z2))

        # ------------- model_zdf: recombine source latents into the deepfake -------------
        z_rec = outputs_z1 + outputs_z2
        outputs_zdf = model_zdf(z_rec)
        lossdf = criterion(outputs_zdf, z_df_tensor)  # latent-space loss (computed for reference, not returned)

        # Decode the recombined latent and measure the pixel-space reconstruction loss.
        rec_df = model_vaq.decode(tensor_jax(outputs_zdf))
        rec_df_tensor = jax_to_tensor(rec_df)
        dfimgloss = criterion(rec_df_tensor, df_img)

        # Convert decoded outputs back to PIL images, clamped to the valid [0, 1] range.
        # (rec_df could also be returned; its Gradio output is disabled below.)
        rec_img1_pil = T.ToPILImage()(jax_to_tensor(rec_img1).squeeze(0).clamp(0, 1))
        rec_img2_pil = T.ToPILImage()(jax_to_tensor(rec_img2).squeeze(0).clamp(0, 1))

    return rec_img1_pil, rec_img2_pil, round(dfimgloss.item(), 3)
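
# Standalone usage sketch (assumes the example images below, e.g. df1.jpg,
# sit next to app.py):
#   src1, src2, loss = gen_sources(Image.open("df1.jpg"))
#   src1.save("source1.png"); src2.save("source2.png")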


# Gradio UI: one input image in, two reconstructed sources and the loss value out.
interface = gr.Interface(
    fn=gen_sources,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Source Image 1"),
        gr.Image(type="pil", label="Source Image 2"),
        # gr.Image(type="pil", label="Deepfake Image"),  # disabled reconstruction preview
        gr.Number(label="Reconstruction Loss"),
    ],
    examples=[["df1.jpg"], ["df2.jpg"], ["df3.jpg"], ["df4.jpg"]],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Images",
    description="Upload a suspected deepfake image to reconstruct its two source images.",
)

interface.launch(debug=True)