File size: 6,086 Bytes
c9224f7 6fdc8f3 6d3fccb c9224f7 e7605fc c9224f7 5db35c9 c9224f7 5db35c9 c9224f7 97cfc26 e49aff4 c9224f7 ca388ea c9224f7 2ac670c af3f95c 7430abb c6178c5 c9224f7 7c7f404 9fdd782 e49aff4 c6178c5 c9224f7 7c7f404 9fdd782 e49aff4 c6178c5 c9224f7 c6178c5 c9224f7 c6178c5 c9224f7 c6178c5 c9224f7 c6178c5 c9224f7 c6178c5 c9224f7 af3f95c 038b7f1 c6178c5 ad112da 2a681ac c9224f7 c6178c5 c9224f7 c6178c5 c9224f7 c6178c5 c9224f7 c6178c5 c9224f7 c6178c5 2a681ac c6178c5 c9224f7 c6178c5 c9224f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import io
from io import BytesIO
import tempfile
import os
import shutil
import requests
import numpy as np
from PIL import Image, ImageOps
import cv2
import math
import matplotlib.pyplot as plt
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.checkpoint import checkpoint
from torchvision.models import vgg16
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.functional import structural_similarity_index_measure
from facenet_pytorch import InceptionResnetV1
from taming.models.vqgan import VQModel
from omegaconf import OmegaConf
from taming.models.vqgan import GumbelVQ
import gradio as gr
from modules.modelz import DeepfakeToSourceTransformer
from modules.frameworkeval import DF
from modules.segmentface import FaceSegmenter
from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar
from gradio_client import Client, file
# Gradio client bound to the "felixrosberg/face-swap" app; gen_sources calls
# its "/run_inference" endpoint to re-synthesize the deepfake.
client = Client("felixrosberg/face-swap")

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config_path = "./models/config.yaml"
config = OmegaConf.load(config_path)

# Extract parameters specific to GumbelVQ
vq_params = config.model.params


def _build_gumbel_vq():
    """Construct one GumbelVQ from `vq_params` and move it to `device`."""
    field_names = (
        "ddconfig",
        "lossconfig",
        "n_embed",
        "embed_dim",
        "kl_weight",
        "temperature_scheduler_config",
    )
    kwargs = {name: getattr(vq_params, name) for name in field_names}
    return GumbelVQ(**kwargs).to(device)


# Initialize the GumbelVQ models: both share the same configuration; their
# weights are loaded later from separate checkpoints.
model_vaq_f = _build_gumbel_vq()
model_vaq_g = _build_gumbel_vq()
##________________________Transformation______________________________
# Preprocessing applied to every PIL image before it is fed to a model:
# resize to 256x256, convert to a tensor, then shift/scale each of the three
# channels with mean 0.5 and std 0.5 (normalizes to [-1, 1]).
transform = T.Compose(
    [
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
#_________________Define:Gradio Function________________________
# Networks used by gen_sources, filled on first use so the checkpoints are
# read from disk once per process rather than once per request.
_INFERENCE_MODELS = {}


def _load_inference_models():
    """Return the four inference networks, loading their weights on first call."""
    if not _INFERENCE_MODELS:
        model_vaq_f.load_state_dict(
            torch.load("./models/model_vaq1_ff.pth", map_location=device),
            strict=True,
        )
        model_vaq_f.eval()

        model_vaq_g.load_state_dict(
            torch.load("./models/model_vaq2_gg.pth", map_location=device),
            strict=True,
        )
        model_vaq_g.eval()

        model_z1 = DeepfakeToSourceTransformer().to(device)
        model_z1.load_state_dict(
            torch.load("./models/model_z1_ff.pth", map_location=device),
            strict=True,
        )
        model_z1.eval()

        model_z2 = DeepfakeToSourceTransformer().to(device)
        model_z2.load_state_dict(
            torch.load("./models/model_z2_gg.pth", map_location=device),
            strict=True,
        )
        model_z2.eval()

        # Publish only after every checkpoint loaded, so a failed load is
        # retried on the next request instead of leaving a partial cache.
        _INFERENCE_MODELS.update(
            vaq_f=model_vaq_f, vaq_g=model_vaq_g, z1=model_z1, z2=model_z2
        )
    return _INFERENCE_MODELS


def _save_temp_png(image):
    """Write `image` to a new temporary PNG file and return the file's path."""
    handle = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        # Closing the handle before returning guarantees the data is flushed
        # and lets other code reopen the file by name on every platform.
        with handle:
            image.save(handle, format="PNG")
    except Exception:
        os.remove(handle.name)
        raise
    return handle.name


def gen_sources(deepfake_img):
    """Recover two source images from a deepfake and re-synthesize the deepfake.

    The input is encoded by `model_vaq_f` (full image) and `model_vaq_g`
    (face-segmented image), each latent is mapped by its own
    DeepfakeToSourceTransformer, and the results are decoded back to images.
    The two recovered images are then sent to the remote face-swap endpoint,
    and the returned image is compared against the input with the `DF`
    criterion.

    Args:
        deepfake_img: PIL image supplied by the Gradio input component.

    Returns:
        A 4-tuple ``(rec_img1_pil, rec_img2_pil, dfimage_pil, loss)``: the two
        recovered PIL images, the PIL image returned by the face-swap
        endpoint, and the first value returned by the criterion rounded to
        three decimals.

    Raises:
        gr.Error: if no image was provided.
    """
    if deepfake_img is None:
        raise gr.Error("Please upload an image first.")

    rgb_img = deepfake_img.convert('RGB')

    #----------------DeepFake Face Segmentation-----------------
    # The segmenter is handed a BGR array, and its output is converted back
    # from BGR to RGB before building the PIL image.
    segmenter = FaceSegmenter(threshold=0.5)
    img_bgr = cv2.cvtColor(np.array(rgb_img), cv2.COLOR_RGB2BGR)
    segmented_np = segmenter.segment_face(img_bgr)
    deepfake_seg = Image.fromarray(cv2.cvtColor(segmented_np, cv2.COLOR_BGR2RGB))

    #------------Initialize Models------------------------
    models = _load_inference_models()
    criterion = DF()

    with torch.no_grad():
        df_img = transform(rgb_img).unsqueeze(0).to(device)
        seg_img = transform(deepfake_seg).unsqueeze(0).to(device)
        z_df, _, _ = models["vaq_f"].encode(df_img)
        z_seg, _, _ = models["vaq_g"].encode(seg_img)
        rec_z_img1 = models["z1"](z_df)
        rec_z_img2 = models["z2"](z_seg)
        rec_img1 = models["vaq_f"].decode(rec_z_img1).squeeze(0)
        rec_img2 = models["vaq_g"].decode(rec_z_img2).squeeze(0)

    rec_img1_pil = T.ToPILImage()(denormalize_bin(rec_img1))
    rec_img2_pil = T.ToPILImage()(denormalize_bin(rec_img2))

    # The client takes file paths, so the recovered images go through
    # temporary PNGs; `finally` removes them even if the remote call fails.
    temp_paths = []
    try:
        for image in (rec_img1_pil, rec_img2_pil):
            temp_paths.append(_save_temp_png(image))
        result = client.predict(
            target=file(temp_paths[0]),
            source=file(temp_paths[1]), slider=100, adv_slider=100,
            settings=["Adversarial Defense"], api_name="/run_inference"
        )
    finally:
        for path in temp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    # Load result and compute loss. The image is used directly (no temp-file
    # round trip) and forced to RGB like the input, so the three-channel
    # Normalize in `transform` applies.
    dfimage_pil = Image.open(result)
    with torch.no_grad():
        # NOTE(review): rec_df goes through denormalize_bin while df_img stays
        # in the range produced by `transform` -- confirm DF expects this.
        rec_df = denormalize_bin(transform(dfimage_pil.convert('RGB'))).unsqueeze(0).to(device)
        rec_loss, _ = criterion(df_img, rec_df)

    return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))
#________________________Create the Gradio interface_________________________________
# One PIL image in; three images and one number out, in the same order as the
# 4-tuple returned by gen_sources.
interface = gr.Interface(
    fn=gen_sources,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Recovered Source Image 1 (Target Image)"),
        gr.Image(type="pil", label="Recovered Source Image 2 (Source Image)"),
        gr.Image(type="pil", label="Reconstructed Deepfake Image"),
        gr.Number(label="Reconstruction Loss")
    ],
    examples = ["./images/df1.jpg","./images/df2.jpg","./images/df3.jpg","./images/df4.jpg"],
    theme = gr.themes.Soft(),
    title="Uncovering Deepfake Image for Identifying Source Images",
    description="Upload an DeepFake image.",
)

# Start the app; the stray trailing "|" that followed this call was not valid
# Python and has been removed.
interface.launch(debug=True)