File size: 6,086 Bytes
c9224f7
6fdc8f3
6d3fccb
c9224f7
 
 
 
 
e7605fc
c9224f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5db35c9
c9224f7
5db35c9
c9224f7
 
97cfc26
 
e49aff4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9224f7
 
 
 
 
ca388ea
c9224f7
 
 
 
 
2ac670c
af3f95c
7430abb
 
 
c6178c5
 
c9224f7
7c7f404
9fdd782
e49aff4
c6178c5
c9224f7
7c7f404
9fdd782
e49aff4
c6178c5
c9224f7
c6178c5
c9224f7
c6178c5
c9224f7
c6178c5
c9224f7
c6178c5
c9224f7
c6178c5
c9224f7
af3f95c
038b7f1
c6178c5
 
 
 
 
ad112da
 
2a681ac
 
 
c9224f7
c6178c5
 
 
 
 
 
 
 
 
c9224f7
c6178c5
c9224f7
c6178c5
 
c9224f7
 
 
c6178c5
 
 
 
c9224f7
c6178c5
 
 
2a681ac
c6178c5
 
 
c9224f7
c6178c5
c9224f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import io
from io import BytesIO 
import tempfile
import os
import shutil
import requests
import numpy as np
from PIL import Image, ImageOps
import cv2
import math
import matplotlib.pyplot as plt
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.checkpoint import checkpoint  
from torchvision.models import vgg16
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.functional import structural_similarity_index_measure
from facenet_pytorch import InceptionResnetV1
from taming.models.vqgan import VQModel
from omegaconf import OmegaConf
from taming.models.vqgan import GumbelVQ
import gradio as gr
from modules.modelz import DeepfakeToSourceTransformer
from modules.frameworkeval import DF
from modules.segmentface import FaceSegmenter
from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar
from gradio_client import Client, file

# Remote face-swap service used to re-synthesize the deepfake from the
# two recovered sources.
client = Client("felixrosberg/face-swap")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config_path = "./models/config.yaml"
config = OmegaConf.load(config_path)
# Extract parameters specific to GumbelVQ
vq_params = config.model.params


def _build_gumbel_vq(params):
    """Construct one GumbelVQ model from the shared config params and move it to `device`."""
    return GumbelVQ(
        ddconfig=params.ddconfig,
        lossconfig=params.lossconfig,
        n_embed=params.n_embed,
        embed_dim=params.embed_dim,
        kl_weight=params.kl_weight,
        temperature_scheduler_config=params.temperature_scheduler_config,
    ).to(device)


# Two identically-configured GumbelVQ instances: `f` encodes/decodes the full
# deepfake image, `g` the segmented face. Checkpoint weights are loaded later
# inside gen_sources().
model_vaq_f = _build_gumbel_vq(vq_params)
model_vaq_g = _build_gumbel_vq(vq_params)

##________________________Transformation______________________________

transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])  # Normalize to [-1, 1]

#_________________Define:Gradio Function________________________

def gen_sources(deepfake_img):
    """Recover the two source images behind a deepfake and rebuild the fake.

    Args:
        deepfake_img: PIL image supplied by the Gradio UI.

    Returns:
        Tuple of (recovered source 1 PIL image, recovered source 2 PIL image,
        reconstructed deepfake PIL image, reconstruction loss rounded to 3 dp).
    """
    #----------------DeepFake Face Segmentation-----------------
    segmenter = FaceSegmenter(threshold=0.5)
    # FaceSegmenter works on BGR arrays (OpenCV convention); convert both ways.
    img_np = np.array(deepfake_img.convert('RGB'))
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    segmented_np = segmenter.segment_face(img_bgr)
    deepfake_seg = Image.fromarray(cv2.cvtColor(segmented_np, cv2.COLOR_BGR2RGB))

    #------------Initialize Models------------------------
    # NOTE(review): all four checkpoints are re-read from disk on every request;
    # consider loading them once at module level if per-call latency matters.
    checkpoint_path_f = "./models/model_vaq1_ff.pth"
    checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
    model_vaq_f.load_state_dict(checkpoint_f, strict=True)
    model_vaq_f.eval()

    checkpoint_path_g = "./models/model_vaq2_gg.pth"
    checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
    model_vaq_g.load_state_dict(checkpoint_g, strict=True)
    model_vaq_g.eval()

    # Latent-space transformers: map the deepfake/segmented latents back to
    # the latents of the two presumed source images.
    model_z1 = DeepfakeToSourceTransformer().to(device)
    model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth", map_location=device), strict=True)
    model_z1.eval()

    model_z2 = DeepfakeToSourceTransformer().to(device)
    model_z2.load_state_dict(torch.load("./models/model_z2_gg.pth", map_location=device), strict=True)
    model_z2.eval()

    criterion = DF()

    with torch.no_grad():
        # Normalize both views to [-1, 1] tensors with a batch dimension.
        df_img = transform(deepfake_img.convert('RGB')).unsqueeze(0).to(device)
        seg_img = transform(deepfake_seg).unsqueeze(0).to(device)

        # Encode -> translate latents -> decode the two recovered sources.
        z_df, _, _ = model_vaq_f.encode(df_img)
        z_seg, _, _ = model_vaq_g.encode(seg_img)
        rec_z_img1 = model_z1(z_df)
        rec_z_img2 = model_z2(z_seg)
        rec_img1 = model_vaq_f.decode(rec_z_img1).squeeze(0)
        rec_img2 = model_vaq_g.decode(rec_z_img2).squeeze(0)
        rec_img1_pil = T.ToPILImage()(denormalize_bin(rec_img1))
        rec_img2_pil = T.ToPILImage()(denormalize_bin(rec_img2))

        # The remote face-swap client takes file paths, so stage the two
        # recovered images as temp PNGs. try/finally guarantees cleanup even
        # if the save or the remote call raises (the originals leaked here).
        temp1_path = temp2_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp1, \
                 tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp2:
                rec_img1_pil.save(temp1, format="PNG")
                rec_img2_pil.save(temp2, format="PNG")
                temp1_path = temp1.name
                temp2_path = temp2.name

            # Re-synthesize the deepfake from the recovered sources.
            result = client.predict(
                target=file(temp1_path),
                source=file(temp2_path), slider=100, adv_slider=100,
                settings=["Adversarial Defense"], api_name="/run_inference"
            )
        finally:
            for path in (temp1_path, temp2_path):
                if path and os.path.exists(path):
                    os.remove(path)

        # Load the re-synthesized deepfake and compute the reconstruction loss
        # against the input. Round-trip through a temp PNG so the same
        # transform pipeline is applied to both images.
        dfimage_pil = Image.open(result)
        temp3_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp3:
                dfimage_pil.save(temp3, format="PNG")
                temp3_path = temp3.name
            rec_df = denormalize_bin(transform(Image.open(temp3_path))).unsqueeze(0).to(device)
        finally:
            if temp3_path and os.path.exists(temp3_path):
                os.remove(temp3_path)

        rec_loss, _ = criterion(df_img, rec_df)

        return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))

#________________________Create the Gradio interface_________________________________
# One image in; three images (two recovered sources + rebuilt deepfake) and
# the scalar reconstruction loss out.
interface = gr.Interface(
    fn=gen_sources,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Recovered Source Image 1 (Target Image)"),
        gr.Image(type="pil", label="Recovered Source Image 2 (Source Image)"),
        gr.Image(type="pil", label="Reconstructed Deepfake Image"),
        gr.Number(label="Reconstruction Loss")
    ],
    examples=["./images/df1.jpg", "./images/df2.jpg", "./images/df3.jpg", "./images/df4.jpg"],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Image for Identifying Source Images",
    # Fixed grammar in the user-facing description ("an DeepFake" -> "a DeepFake").
    description="Upload a DeepFake image.",
)

# debug=True surfaces tracebacks in the browser/console while developing.
interface.launch(debug=True)