import tempfile

import cv2
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

import gradio as gr
from gradio_client import Client, handle_file

from modules.finetunedvqgan import Generator
from modules.modelz import DeepfakeToSourceTransformer
from modules.frameworkeval import DF
from modules.segmentface import FaceSegmenter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
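
# ---------------------Remote face-swap client-------------------------
# Used to rebuild the deepfake from the two recovered sources. The Space
# ID below is an assumption (the original file never defines `client`);
# point it at whichever Gradio Space exposes the /run_inference API
# (target/source/slider/adv_slider/settings) that gen_sources calls.
client = Client("felixrosberg/face-swap")  # assumed endpoint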

##________________________Transformation______________________________

transform = T.Compose([
    T.Resize((256, 256)),   
    T.ToTensor(),         
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])  # Normalize to [-1, 1]
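
# Minimal inverse of the Normalize above (an assumption: the VQGAN decoder
# outputs tensors in [-1, 1], matching the input normalization). ToPILImage
# expects floats in [0, 1], so map back before converting.
def denorm(x: torch.Tensor) -> torch.Tensor:
    return (x.clamp(-1.0, 1.0) + 1.0) / 2.0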

#_________________Define:Gradio Function________________________
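# Pipeline: segment the face from the deepfake, encode both views with two
# fine-tuned VQGANs, map each view's latents to the corresponding source
# latents with DeepfakeToSourceTransformer, decode the sources, then re-run
# the face swap remotely and score the rebuilt deepfake against the input.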

def gen_sources(deepfake_img):
    """Recover the two source faces behind a deepfake image, rebuild the
    deepfake from them, and report the reconstruction loss."""
    # ----------------Deepfake face segmentation-----------------
    segmenter = FaceSegmenter(threshold=0.5)
    # Convert the PIL input to a BGR numpy array for OpenCV-based segmentation
    img_np = np.array(deepfake_img.convert('RGB'))
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    # Segment the face, then convert the BGR result back to a PIL image
    segmented_np = segmenter.segment_face(img_bgr)
    deepfake_seg = Image.fromarray(cv2.cvtColor(segmented_np, cv2.COLOR_BGR2RGB))
    # ------------Initialize decoders F and G------------------------
    # (Models are reloaded on every call; hoist them to module level if
    # per-request latency matters.)
    config_path = "./models/config.yaml"
    checkpoint_path_f = "./models/model_vaq1_ff.pth"
    model_vaq_f = Generator(config_path, checkpoint_path_f, device)
    checkpoint_path_g = "./models/model_vaq2_gg.pth"
    model_vaq_g = Generator(config_path, checkpoint_path_g, device)
    ##------------------------Initialize latent mappers F and G------
    model_z1 = DeepfakeToSourceTransformer().to(device)
    model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth", map_location=device), strict=True)
    model_z1.eval()
    model_z2 = DeepfakeToSourceTransformer().to(device)
    model_z2.load_state_dict(torch.load("./models/model_z2_gg.pth", map_location=device), strict=True)
    model_z2.eval()
    ##--------------------Initialize evaluation criterion------------
    criterion = DF()
    
    ##----------------------Operation---------------------------------
    with torch.no_grad():
        # Preprocess both views (they are already PIL images, so no
        # Image.open round-trip is needed)
        img = deepfake_img.convert('RGB')
        segimg = deepfake_seg.convert('RGB')
        df_img = transform(img).unsqueeze(0).to(device)  # Shape: (1, 3, 256, 256)
        seg_img = transform(segimg).unsqueeze(0).to(device)
        
        # Encode both views to quantized VQGAN latents, map them to the
        # estimated source latents, then decode back to images
        z_df, _, _ = model_vaq_f.encode(df_img)
        z_seg, _, _ = model_vaq_g.encode(seg_img)
        rec_z_img1 = model_z1(z_df)
        rec_z_img2 = model_z2(z_seg)
        rec_img1 = denorm(model_vaq_f.decode(rec_z_img1).squeeze(0))
        rec_img2 = denorm(model_vaq_g.decode(rec_z_img2).squeeze(0))
        rec_img1_pil = T.ToPILImage()(rec_img1)
        rec_img2_pil = T.ToPILImage()(rec_img2)

        # gradio_client's handle_file() expects a filepath or URL rather
        # than an in-memory buffer, so write the recovered sources to
        # temporary PNG files first
        tmp1 = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        tmp2 = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        rec_img1_pil.save(tmp1.name, format="PNG")
        rec_img2_pil.save(tmp2.name, format="PNG")

        # Rebuild the deepfake from the two recovered sources
        result = client.predict(
            target=handle_file(tmp1.name),
            source=handle_file(tmp2.name),
            slider=100,
            adv_slider=100,
            settings=["Adversarial Defense"],
            api_name="/run_inference",
        )

        # Load the swapped result (the client returns a local filepath)
        # and compute the reconstruction loss against the input deepfake
        dfimage_pil = Image.open(result).convert('RGB')
        rec_df = transform(dfimage_pil).unsqueeze(0).to(device)
        rec_loss, _ = criterion(df_img, rec_df)

        return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))

#________________________Create the Gradio interface_________________________________
interface = gr.Interface(
    fn=gen_sources,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Recovered Source Image 1 (Target Image)"),
        gr.Image(type="pil", label="Recovered Source Image 2 (Source Image)"),
        gr.Image(type="pil", label="Reconstructed Deepfake Image"),
        gr.Number(label="Reconstruction Loss")
    ],
    examples=["./images/df1.jpg", "./images/df2.jpg", "./images/df3.jpg", "./images/df4.jpg"],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Images for Identifying Source Images",
    description="Upload a deepfake image.",
)

interface.launch(debug=True)