16march2025
app.py
CHANGED
@@ -2,215 +2,127 @@ import io
Old version (lines 2-216 of app.py before this commit; deleted lines are marked "-", unchanged context lines are kept without a marker):

  import os
  import shutil
  import requests
- import time
  import numpy as np
  from PIL import Image, ImageOps
- from math import nan
  import math
  import pickle
- import warnings
- warnings.filterwarnings("ignore")
-
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- import torch.optim as optim
-
- from torch.utils.data import Dataset, ConcatDataset, DataLoader
- from torchvision.datasets import ImageFolder
  import torchvision.transforms as T
  import torchvision.transforms.functional as TF
- from torch.
- import
-
-
- import
- from
- from
-
  import gradio as gr

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


- class Model_Z1(nn.Module):
-     def __init__(self):
-         super(Model_Z1, self).__init__()
-         self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
-         self.batchnorm = nn.BatchNorm2d(2048)
-         self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
-         self.batchnorm2 = nn.BatchNorm2d(256)
-         self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
-         self.batchnorm3 = nn.BatchNorm2d(1024)
-         self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
-         self.batchnorm4 = nn.BatchNorm2d(256)
-         self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
-         self.batchnorm5 = nn.BatchNorm2d(512)
-         self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
-         self.elu = nn.ELU()
-
-     def forward(self, x):
-         res = x
-         x = self.elu(self.conv1(x))
-         x = self.batchnorm(x)
-         x = self.elu(self.conv2(x)) + res
-         x = self.batchnorm2(x)
-         x = self.elu(self.conv3(x))
-         x = self.batchnorm3(x)
-         x = self.elu(self.conv4(x)) + res
-         x = self.batchnorm4(x)
-         x = self.elu(self.conv5(x))
-         x = self.batchnorm5(x)
-         out = self.elu(self.conv6(x)) + res
-         return out
-
- class Model_Z(nn.Module):
-     def __init__(self):
-         super(Model_Z, self).__init__()
-         self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
-         self.batchnorm = nn.BatchNorm2d(2048)
-         self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
-         self.batchnorm2 = nn.BatchNorm2d(256)
-         self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
-         self.batchnorm3 = nn.BatchNorm2d(1024)
-         self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
-         self.batchnorm4 = nn.BatchNorm2d(256)
-         self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
-         self.batchnorm5 = nn.BatchNorm2d(512)
-         self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
-         self.batchnorm6 = nn.BatchNorm2d(256)
-         self.conv7 = nn.Conv2d(in_channels=256, out_channels=448, kernel_size=3, padding=1)
-         self.batchnorm7 = nn.BatchNorm2d(448)
-         self.conv8 = nn.Conv2d(in_channels=448, out_channels=384, kernel_size=3, padding=1)
-         self.batchnorm8 = nn.BatchNorm2d(384)
-         self.conv9 = nn.Conv2d(in_channels=384, out_channels=320, kernel_size=3, padding=1)
-         self.batchnorm9 = nn.BatchNorm2d(320)
-         self.conv10 = nn.Conv2d(in_channels=320, out_channels=256, kernel_size=3, padding=1)
-         self.elu = nn.ELU()
-
-     def forward(self, x):
-         res = x
-         x = self.elu(self.conv1(x))
-         x = self.batchnorm(x)
-         x = self.elu(self.conv2(x)) + res
-         x = self.batchnorm2(x)
-         x = self.elu(self.conv3(x))
-         x = self.batchnorm3(x)
-         x = self.elu(self.conv4(x)) + res
-         x = self.batchnorm4(x)
-         x = self.elu(self.conv5(x))
-         x = self.batchnorm5(x)
-         x = self.elu(self.conv6(x)) + res
-         x = self.batchnorm6(x)
-         x = self.elu(self.conv7(x))
-         x = self.batchnorm7(x)
-         x = self.elu(self.conv8(x))
-         x = self.batchnorm8(x)
-         x = self.elu(self.conv9(x))
-         x = self.batchnorm9(x)
-         out = self.elu(self.conv10(x)) + res
-         return out
-
-
- def tensor_jax(x):
-     if x.dim() == 3:
-         x = x.unsqueeze(0)
-
-     x_np = x.detach().permute(0, 2, 3, 1).cpu().numpy()  # Convert from (N, C, H, W) to (N, H, W, C) and move to CPU
-     x_jax = jnp.array(x_np)
-     return x_jax
-
- def jax_to_tensor(x):
-     x_tensor = torch.tensor(np.array(x), requires_grad=True).permute(0, 3, 1, 2).to(device)  # Convert from (N, H, W, C) to (N, C, H, W)
-     return x_tensor
-
- # Define the transform
  transform = T.Compose([
-     T.Resize((256, 256)),
-     T.ToTensor()
- ])
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
      model_z1.eval()
      model_z2.eval()
-
-
      with torch.no_grad():
-
-
-
-         df_img =
-
-
-         #
-         z_df,_ =
-
-
-
-
-
-         #generate img1
-         z1_rec_jax = tensor_jax(outputs_z1)
-         rec_img1 = model_vaq.decode(z1_rec_jax)
-         ##----------------------------------------------------------------------
-         ##----------------------model_z2-----------------------
-         outputs_z2 = model_z2(z_df_tensor)
-         #generate img2
-         z2_rec_jax = tensor_jax(outputs_z2)
-         rec_img2 = model_vaq.decode(z2_rec_jax)
-         ##----------------------------------------------------------------------
-         ##----------------------model_zdf-----------------------
-         z_rec = outputs_z1 + outputs_z2
-         outputs_zdf = model_zdf(z_rec)
-         lossdf = criterion(outputs_zdf, z_df_tensor)
-         #calculate dfimg reconstruction loss
-         zdf_rec_jax = tensor_jax(outputs_zdf)
-         rec_df = model_vaq.decode(zdf_rec_jax)
-         rec_df_tensor = jax_to_tensor(rec_df)
-         dfimgloss = criterion(rec_df_tensor, df_img)
-         # Convert tensor back to a PIL image
-         rec_img1 = jax_to_tensor(rec_img1)
          rec_img1 = rec_img1.squeeze(0)
-         rec_img2 = jax_to_tensor(rec_img2)
          rec_img2 = rec_img2.squeeze(0)
-         rec_df = jax_to_tensor(rec_df)
-         rec_df = rec_df.squeeze(0)
          rec_img1_pil = T.ToPILImage()(rec_img1)
          rec_img2_pil = T.ToPILImage()(rec_img2)
-         rec_df_pil = T.ToPILImage()(rec_df)
-
-         return (rec_img1_pil, rec_img2_pil, round(dfimgloss.item(),3))

- #
  interface = gr.Interface(
      fn=gen_sources,
      inputs=gr.Image(type="pil", label="Input Image"),
      outputs=[
-         gr.Image(type="pil", label="Source Image 1"),
-         gr.Image(type="pil", label="Source Image 2"),
-
          gr.Number(label="Reconstruction Loss")
      ],
      examples = ["./df1.jpg","./df2.jpg","./df3.jpg","./df4.jpg"],
      theme = gr.themes.Soft(),
-     title="Uncovering Deepfake Image",
-     description="Upload an image.",
  )

- interface.launch()
New version (lines 2-128 of app.py after this commit; added lines are marked "+"):

  import os
  import shutil
  import requests
  import numpy as np
  from PIL import Image, ImageOps
  import math
+ import matplotlib.pyplot as plt
  import pickle
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  import torchvision.transforms as T
  import torchvision.transforms.functional as TF
+ from torch.utils.checkpoint import checkpoint
+ from torchvision.models import vgg16
+ from torchmetrics.image.fid import FrechetInceptionDistance
+ from torchmetrics.functional import structural_similarity_index_measure
+ from facenet_pytorch import InceptionResnetV1
+ from taming.models.vqgan import VQModel
+ from omegaconf import OmegaConf
+ from taming.models.vqgan import GumbelVQ
  import gradio as gr
+ from finetunedvqgan import Generator
+ from modelz import DeepfakeToSourceTransformer
+ from frameworkeval import DF
+ from segmentface import FaceSegmenter
+ from denormalize import denormalize_bin, denormalize_tr, denormalize_ar
+ from io import BytesIO                   # missing from the commit but required by the BytesIO buffers used below
+ from gradio_client import Client, file   # missing from the commit but required by the client.predict(...) call used below

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ ##________________________Transformation______________________________

  transform = T.Compose([
+     T.Resize((256, 256)),
+     T.ToTensor(),
+     T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])  # Normalize to [-1, 1]
+
+ #_________________Define:Gradio Function________________________
+
+ def gen_sources(deepfake_img):
+     config_path = "./config.yaml"
+     #------------Initialize:Decoder-F------------------------
+     checkpoint_path_f = "./model_vaq1_ff.pth"
+     checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
+     model_vaq_f = Generator(config_path)
+     model_vaq_f.load_state_dict(checkpoint_f, strict=True)  # load in place; load_state_dict() does not return the model
+     model_vaq_f.eval()
+     #------------Initialize:Decoder-G------------------------
+     checkpoint_path_g = "./model_vaq2_gg.pth"
+     checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
+     model_vaq_g = Generator(config_path)
+     model_vaq_g.load_state_dict(checkpoint_g, strict=True)
+     model_vaq_g.eval()
+     ##------------------------Initialize Model-F-------------------------------------
+     model_z1 = DeepfakeToSourceTransformer().to(device)
+     model_z1.load_state_dict(torch.load("./model_z1_ff.pth", map_location=device), strict=True)
      model_z1.eval()
+     ##------------------------Initialize Model-G-------------------------------------
+     model_z2 = DeepfakeToSourceTransformer().to(device)
+     model_z2.load_state_dict(torch.load("./model_z2_gg.pth", map_location=device), strict=True)
      model_z2.eval()
+     ##--------------------Initialize:Evaluation---------------------------------------
+     criterion = DF()
+     ##----------------------Initialize:Face Segmentation----------------------------------
+     segmenter = FaceSegmenter(threshold=0.5)
+     #----------------DeepFake Face Segmentation-----------------
+     deepfake_seg = segmenter.segment_face(deepfake_img)
+
+     ##----------------------Operation-------------------------------------------------
      with torch.no_grad():
+         # Load and preprocess input image
+         img = Image.open(deepfake_img).convert('RGB')
+         segimg = Image.open(deepfake_seg).convert('RGB')
+         df_img = transform(img).unsqueeze(0).to(device)  # Shape: (1, 3, 256, 256)
+         seg_img = transform(segimg).unsqueeze(0).to(device)
+
+         # Calculate quantized_block for all images
+         z_df, _, _ = model_vaq_f.encode(df_img)
+         z_seg, _, _ = model_vaq_g.encode(seg_img)
+         rec_z_img1 = model_z1(z_df)
+         rec_z_img2 = model_z2(z_seg)
+         rec_img1 = model_vaq_f.decode(rec_z_img1)
+         rec_img2 = model_vaq_g.decode(rec_z_img2)
          rec_img1 = rec_img1.squeeze(0)
          rec_img2 = rec_img2.squeeze(0)
          rec_img1_pil = T.ToPILImage()(rec_img1)
          rec_img2_pil = T.ToPILImage()(rec_img2)

+         # Save PIL images to in-memory buffers
+         buffer1 = BytesIO()
+         buffer2 = BytesIO()
+         rec_img1_pil.save(buffer1, format="PNG")
+         rec_img2_pil.save(buffer2, format="PNG")
+
+         # Pass buffers to the Gradio client ('client' is a gradio_client.Client; it is not created anywhere in this file)
+         result = client.predict(
+             target=file(buffer1),
+             source=file(buffer2), slider=100, adv_slider=100,
+             settings=["Adversarial Defense"], api_name="/run_inference"
+         )
+
+         # Load result and compute loss
+         dfimage_pil = Image.open(result)  # open the resulting image
+         buffer3 = BytesIO()
+         dfimage_pil.save(buffer3, format="PNG")
+         rec_df = transform(Image.open(buffer3)).unsqueeze(0).to(device)
+         rec_loss, _ = criterion(df_img, rec_df)
+
+         return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))
+
+ #________________________Create the Gradio interface_________________________________
  interface = gr.Interface(
      fn=gen_sources,
      inputs=gr.Image(type="pil", label="Input Image"),
      outputs=[
+         gr.Image(type="pil", label="Recovered Source Image 1 (Target Image)"),
+         gr.Image(type="pil", label="Recovered Source Image 2 (Source Image)"),
+         gr.Image(type="pil", label="Reconstructed Deepfake Image"),
          gr.Number(label="Reconstruction Loss")
      ],
      examples = ["./df1.jpg","./df2.jpg","./df3.jpg","./df4.jpg"],
      theme = gr.themes.Soft(),
+     title="Uncovering Deepfake Images to Identify Source Images",
+     description="Upload a deepfake image.",
  )

+ interface.launch()
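Once the Space is running, the gr.Interface above can also be queried programmatically. A minimal sketch, assuming a hypothetical Space id ("username/16march2025" is a placeholder, not the real id) and gradio_client's default "/predict" endpoint for a gr.Interface; the example image is one of the files bundled with the Space:

from gradio_client import Client, handle_file

# Connect to the deployed Space (placeholder id) and send one example deepfake image.
client = Client("username/16march2025")
src1_path, src2_path, df_path, rec_loss = client.predict(
    handle_file("./df1.jpg"),   # input deepfake image
    api_name="/predict",        # default endpoint name for a gr.Interface
)
print(src1_path, src2_path, df_path, rec_loss)  # two recovered sources, reconstructed deepfake, loss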