import tempfile
import torch
import torchvision.transforms as T
from PIL import Image
import gradio as gr
from gradio_client import Client, file
from finetunedvqgan import Generator
from modelz import DeepfakeToSourceTransformer
from frameworkeval import DF
from segmentface import FaceSegmenter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
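##____________________Remote Face-Swap Client_____________________
# client.predict() below assumes a running Gradio Space that exposes a
# "/run_inference" endpoint. The Space ID here is a placeholder (assumption);
# replace it with the actual deployment before launching.
client = Client("username/faceswap-space")  # placeholder Space ID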
##________________________Transformation______________________________
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])
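#_________________Denormalization Helper________________________
# Minimal sketch, assuming the VQGAN decoders return tensors in [-1, 1]
# (the inverse of the Normalize above); maps back to [0, 1] for ToPILImage.
def to_pil(tensor):
    tensor = (tensor.clamp(-1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
    return T.ToPILImage()(tensor.cpu())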
#_________________Define:Gradio Function________________________
def gen_sources(deepfake_img):
    # Note: all models are (re)loaded on every call; moving this setup to
    # module level would avoid the repeated disk reads.
    config_path = "./config.yaml"
    #----------------DeepFake Face Segmentation-----------------
    # The segmenter must be created before it is used.
    segmenter = FaceSegmenter(threshold=0.5)
    deepfake_seg = segmenter.segment_face(deepfake_img)
    #------------Initialize:Decoder-F------------------------
    checkpoint_path_f = "./model_vaq1_ff.pth"
    checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
    model_vaq_f = Generator(config_path).to(device)
    # load_state_dict() returns a key report, not the model, so its result
    # must not be assigned back to the model variable.
    model_vaq_f.load_state_dict(checkpoint_f, strict=True)
    model_vaq_f.eval()
    #------------Initialize:Decoder-G------------------------
    checkpoint_path_g = "./model_vaq2_gg.pth"
    checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
    model_vaq_g = Generator(config_path).to(device)
    model_vaq_g.load_state_dict(checkpoint_g, strict=True)
    model_vaq_g.eval()
    ##------------------------Initialize Model-F-------------------------------------
    model_z1 = DeepfakeToSourceTransformer().to(device)
    model_z1.load_state_dict(torch.load("./model_z1_ff.pth", map_location=device), strict=True)
    model_z1.eval()
    ##------------------------Initialize Model-G-------------------------------------
    model_z2 = DeepfakeToSourceTransformer().to(device)
    model_z2.load_state_dict(torch.load("./model_z2_gg.pth", map_location=device), strict=True)
    model_z2.eval()
    ##--------------------Initialize:Evaluation---------------------------------------
    criterion = DF()
    ##----------------------Operation-------------------------------------------------
    with torch.no_grad():
        # The Gradio input is declared type="pil", so deepfake_img is already
        # a PIL image; Image.open() is not needed. Assumption: segment_face()
        # also returns a PIL image (wrap it in Image.open if it returns a path).
        img = deepfake_img.convert('RGB')
        segimg = deepfake_seg.convert('RGB')
        df_img = transform(img).unsqueeze(0).to(device)  # Shape: (1, 3, 256, 256)
        seg_img = transform(segimg).unsqueeze(0).to(device)
        # Encode both images into quantized latent blocks
        z_df, _, _ = model_vaq_f.encode(df_img)
        z_seg, _, _ = model_vaq_g.encode(seg_img)
        # Map the deepfake/segmented latents to the two source latents
        rec_z_img1 = model_z1(z_df)
        rec_z_img2 = model_z2(z_seg)
        rec_img1 = model_vaq_f.decode(rec_z_img1)
        rec_img2 = model_vaq_g.decode(rec_z_img2)
        # Decoder output is assumed to be in [-1, 1]; denormalize before
        # converting to PIL (see to_pil above).
        rec_img1_pil = to_pil(rec_img1.squeeze(0))
        rec_img2_pil = to_pil(rec_img2.squeeze(0))
        # gradio_client's file() expects a filepath, so write the recovered
        # images to temporary PNG files rather than in-memory buffers.
        tmp1 = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        tmp2 = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        rec_img1_pil.save(tmp1.name, format="PNG")
        rec_img2_pil.save(tmp2.name, format="PNG")
        # Pass the file paths to the remote face-swap Space
        result = client.predict(
            target=file(tmp1.name),
            source=file(tmp2.name), slider=100, adv_slider=100,
            settings=["Adversarial Defense"], api_name="/run_inference"
        )
        # Load the result and compute the reconstruction loss
        dfimage_pil = Image.open(result)  # result is a filepath returned by the Space
        rec_df = transform(dfimage_pil.convert('RGB')).unsqueeze(0).to(device)
        rec_loss, _ = criterion(df_img, rec_df)
        return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))
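##____________________Quick Local Check (optional)_____________________
# Hedged example, left commented out so the script has no side effects on
# import: exercise the pipeline once on a bundled sample before launching
# the UI. "./df1.jpg" is one of the example images listed below.
# _, _, _, loss = gen_sources(Image.open("./df1.jpg"))
# print("Reconstruction loss:", loss)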
#________________________Create the Gradio interface_________________________________
interface = gr.Interface(
    fn=gen_sources,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Recovered Source Image 1 (Target Image)"),
        gr.Image(type="pil", label="Recovered Source Image 2 (Source Image)"),
        gr.Image(type="pil", label="Reconstructed Deepfake Image"),
        gr.Number(label="Reconstruction Loss")
    ],
    examples=["./df1.jpg", "./df2.jpg", "./df3.jpg", "./df4.jpg"],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Images to Identify Source Images",
    description="Upload a deepfake image.",
)
interface.launch()