import tempfile

import torch
import torchvision.transforms as T
from PIL import Image
import gradio as gr
from gradio_client import Client, file

from finetunedvqgan import Generator
from modelz import DeepfakeToSourceTransformer
from frameworkeval import DF
from segmentface import FaceSegmenter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
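##________________________Face-Swap Client______________________________
# The original calls client.predict(...) without defining `client`. A
# minimal sketch using gradio_client, whose API the predict(...) call
# below implies; the Space ID is a placeholder, not taken from the
# original, and must point at the face-swap Space exposing /run_inference.
client = Client("<face-swap-space-id>")  # hypothetical Space ID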
##________________________Transformation______________________________
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])  # Normalize to [-1, 1]
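# Helper (not in the original): the transform above maps pixels to [-1, 1]
# and the VQGAN decoder outputs the same range, while T.ToPILImage()
# expects [0, 1], so decoded tensors are denormalized before conversion.
def denorm(x):
    return (x.clamp(-1, 1) + 1) / 2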
#_________________Define:Gradio Function________________________
def gen_sources(deepfake_img):
    config_path = "./config.yaml"
    ##----------------------Initialize:Face Segmentation----------------------------------
    segmenter = FaceSegmenter(threshold=0.5)
    #----------------DeepFake Face Segmentation-----------------
    deepfake_seg = segmenter.segment_face(deepfake_img)
    #------------Initialize:Decoder-F------------------------
    checkpoint_path_f = "./model_vaq1_ff.pth"
    checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
    model_vaq_f = Generator(config_path).to(device)
    # load_state_dict returns a key-compatibility report, not the model,
    # so its result must not be assigned back to the variable
    model_vaq_f.load_state_dict(checkpoint_f, strict=True)
    model_vaq_f.eval()
    #------------Initialize:Decoder-G------------------------
    checkpoint_path_g = "./model_vaq2_gg.pth"
    checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
    model_vaq_g = Generator(config_path).to(device)
    model_vaq_g.load_state_dict(checkpoint_g, strict=True)
    model_vaq_g.eval()
    ##------------------------Initialize Model-F-------------------------------------
    model_z1 = DeepfakeToSourceTransformer().to(device)
    model_z1.load_state_dict(torch.load("./model_z1_ff.pth", map_location=device), strict=True)
    model_z1.eval()
    ##------------------------Initialize Model-G-------------------------------------
    model_z2 = DeepfakeToSourceTransformer().to(device)
    model_z2.load_state_dict(torch.load("./model_z2_gg.pth", map_location=device), strict=True)
    model_z2.eval()
    ##--------------------Initialize:Evaluation---------------------------------------
    criterion = DF()
    ##----------------------Operation-------------------------------------------------
    with torch.no_grad():
        # Load and preprocess the input image (deepfake_img is a file
        # path, matching type="filepath" in the interface below)
        img = Image.open(deepfake_img).convert('RGB')
        segimg = Image.open(deepfake_seg).convert('RGB')
        df_img = transform(img).unsqueeze(0).to(device)  # Shape: (1, 3, 256, 256)
        seg_img = transform(segimg).unsqueeze(0).to(device)
        # Calculate quantized latents for both images
        z_df, _, _ = model_vaq_f.encode(df_img)
        z_seg, _, _ = model_vaq_g.encode(seg_img)
        # Map the deepfake latents back to the two source latents
        rec_z_img1 = model_z1(z_df)
        rec_z_img2 = model_z2(z_seg)
        # Decode and denormalize from [-1, 1] to [0, 1] before PIL conversion
        rec_img1 = denorm(model_vaq_f.decode(rec_z_img1).squeeze(0))
        rec_img2 = denorm(model_vaq_g.decode(rec_z_img2).squeeze(0))
        rec_img1_pil = T.ToPILImage()(rec_img1)
        rec_img2_pil = T.ToPILImage()(rec_img2)
        # Save the recovered sources to temporary files: gradio_client's
        # file() takes a filepath or URL, not an in-memory buffer
        tmp1 = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        tmp2 = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        rec_img1_pil.save(tmp1.name, format="PNG")
        rec_img2_pil.save(tmp2.name, format="PNG")
        # Re-create the deepfake from the two recovered sources
        result = client.predict(
            target=file(tmp1.name),
            source=file(tmp2.name),
            slider=100,
            adv_slider=100,
            settings=["Adversarial Defense"],
            api_name="/run_inference",
        )
        # Load the result and compute the reconstruction loss against
        # the original deepfake
        dfimage_pil = Image.open(result).convert('RGB')
        rec_df = transform(dfimage_pil).unsqueeze(0).to(device)
        rec_loss, _ = criterion(df_img, rec_df)
    return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))
#________________________Create the Gradio interface_________________________________
interface = gr.Interface(
    fn=gen_sources,
    # type="filepath" so gen_sources receives a path it can Image.open()
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Recovered Source Image 1 (Target Image)"),
        gr.Image(type="pil", label="Recovered Source Image 2 (Source Image)"),
        gr.Image(type="pil", label="Reconstructed Deepfake Image"),
        gr.Number(label="Reconstruction Loss"),
    ],
    examples=["./df1.jpg", "./df2.jpg", "./df3.jpg", "./df4.jpg"],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Image for Identifying Source Images",
    description="Upload a deepfake image.",
)
interface.launch(debug=True)