2025_march16
- app.py +128 -0
- images/df1.jpg +0 -0
- images/df2.jpg +0 -0
- images/df3.jpg +0 -0
- images/df4.jpg +0 -0
- modules/.ipynb_checkpoints/denormalize-checkpoint.py +21 -0
- modules/.ipynb_checkpoints/finetunedvqgan-checkpoint.py +31 -0
- modules/.ipynb_checkpoints/frameworkeval-checkpoint.py +56 -0
- modules/.ipynb_checkpoints/modelz-checkpoint.py +155 -0
- modules/.ipynb_checkpoints/segmentface-checkpoint.py +75 -0
- modules/denormalize.py +21 -0
- modules/finetunedvqgan.py +31 -0
- modules/frameworkeval.py +56 -0
- modules/modelz.py +155 -0
- modules/segmentface.py +75 -0
app.py
ADDED
@@ -0,0 +1,128 @@
import io
from io import BytesIO
import os
import shutil
import requests
import numpy as np
import cv2
from PIL import Image, ImageOps
import math
import matplotlib.pyplot as plt
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.checkpoint import checkpoint
from torchvision.models import vgg16
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.functional import structural_similarity_index_measure
from facenet_pytorch import InceptionResnetV1
from taming.models.vqgan import VQModel
from omegaconf import OmegaConf
from taming.models.vqgan import GumbelVQ
import gradio as gr
from gradio_client import Client, file  # `file` wraps files for a remote Space call
from modules.finetunedvqgan import Generator
from modules.modelz import DeepfakeToSourceTransformer
from modules.frameworkeval import DF
from modules.segmentface import FaceSegmenter
from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##________________________Transformation______________________________

transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])  # Normalize to [-1, 1]

#_________________Define:Gradio Function________________________

def gen_sources(deepfake_img):
    ##----------------------Initialize:Face Segmentation------------------------------
    segmenter = FaceSegmenter(threshold=0.5)
    #----------------DeepFake Face Segmentation-----------------
    deepfake_seg = segmenter.segment_face(deepfake_img)  # ndarray (BGR on the OpenCV path)
    config_path = "./models/config.yaml"
    #------------Initialize:Decoder-F------------------------
    checkpoint_path_f = "./models/model_vaq1_ff.pth"
    checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
    model_vaq_f = Generator(config_path, device).load_models()
    model_vaq_f.load_state_dict(checkpoint_f, strict=True)
    model_vaq_f.eval()
    #------------Initialize:Decoder-G------------------------
    checkpoint_path_g = "./models/model_vaq2_gg.pth"
    checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
    model_vaq_g = Generator(config_path, device).load_models()
    model_vaq_g.load_state_dict(checkpoint_g, strict=True)
    model_vaq_g.eval()
    ##------------------------Initialize Model-F-------------------------------------
    model_z1 = DeepfakeToSourceTransformer().to(device)
    model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth", map_location=device), strict=True)
    model_z1.eval()
    ##------------------------Initialize Model-G-------------------------------------
    model_z2 = DeepfakeToSourceTransformer().to(device)
    model_z2.load_state_dict(torch.load("./models/model_z2_gg.pth", map_location=device), strict=True)
    model_z2.eval()
    ##--------------------Initialize:Evaluation---------------------------------------
    criterion = DF()

    ##----------------------Operation-------------------------------------------------
    with torch.no_grad():
        # Load and preprocess input image
        img = Image.open(deepfake_img).convert('RGB')
        segimg = Image.fromarray(cv2.cvtColor(deepfake_seg, cv2.COLOR_BGR2RGB)).convert('RGB')
        df_img = transform(img).unsqueeze(0).to(device)  # Shape: (1, 3, 256, 256)
        seg_img = transform(segimg).unsqueeze(0).to(device)

        # Calculate quantized_block for all images
        z_df, _, _ = model_vaq_f.encode(df_img)
        z_seg, _, _ = model_vaq_g.encode(seg_img)
        rec_z_img1 = model_z1(z_df)
        rec_z_img2 = model_z2(z_seg)
        rec_img1 = model_vaq_f.decode(rec_z_img1)
        rec_img2 = model_vaq_g.decode(rec_z_img2)
        rec_img1 = denormalize_bin(rec_img1.squeeze(0))  # [-1, 1] -> [0, 1] for ToPILImage
        rec_img2 = denormalize_bin(rec_img2.squeeze(0))
        rec_img1_pil = T.ToPILImage()(rec_img1)
        rec_img2_pil = T.ToPILImage()(rec_img2)

        # Save PIL images to in-memory buffers
        buffer1 = BytesIO()
        buffer2 = BytesIO()
        rec_img1_pil.save(buffer1, format="PNG")
        rec_img2_pil.save(buffer2, format="PNG")

        # Pass buffers to the Gradio client (NOTE: `client` is never initialized in this file
        # and the face-swap Space it should point to is not specified in this commit;
        # gradio_client's `file` expects a path/URL, so these buffers may need to be written
        # to temporary files first)
        result = client.predict(
            target=file(buffer1),
            source=file(buffer2), slider=100, adv_slider=100,
            settings=["Adversarial Defense"], api_name="/run_inference"
        )

        # Load result and compute loss
        dfimage_pil = Image.open(result)  # the client returns a path to the swapped image
        buffer3 = BytesIO()
        dfimage_pil.save(buffer3, format="PNG")
        buffer3.seek(0)
        rec_df = transform(Image.open(buffer3)).unsqueeze(0).to(device)
        rec_loss, _ = criterion(df_img, rec_df)

    return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))

#________________________Create the Gradio interface_________________________________
interface = gr.Interface(
    fn=gen_sources,
    inputs=gr.Image(type="filepath", label="Input Image"),  # the function expects a file path
    outputs=[
        gr.Image(type="pil", label="Recovered Source Image 1 (Target Image)"),
        gr.Image(type="pil", label="Recovered Source Image 2 (Source Image)"),
        gr.Image(type="pil", label="Reconstructed Deepfake Image"),
        gr.Number(label="Reconstruction Loss")
    ],
    examples=["./images/df1.jpg", "./images/df2.jpg", "./images/df3.jpg", "./images/df4.jpg"],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Images for Identifying Source Images",
    description="Upload a deepfake image.",
)

interface.launch(debug=True)
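Once the app is running, it can also be queried programmatically. The sketch below is illustrative only and not part of the commit: the local URL and the "/predict" endpoint are assumptions (they are the defaults printed by interface.launch() and used by a single-function gr.Interface), and the three image outputs come back as file paths.

from gradio_client import Client, file

client = Client("http://127.0.0.1:7860")        # local URL printed by interface.launch()
outputs = client.predict(
    file("./images/df1.jpg"),                   # the single image input
    api_name="/predict"                         # default endpoint for a gr.Interface
)
rec_src1, rec_src2, rec_df, rec_loss = outputs  # three image paths plus the loss value
print(rec_loss)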
images/df1.jpg
ADDED
images/df2.jpg
ADDED
images/df3.jpg
ADDED
images/df4.jpg
ADDED
modules/.ipynb_checkpoints/denormalize-checkpoint.py
ADDED
@@ -0,0 +1,21 @@
(contents identical to modules/denormalize.py, shown below)
modules/.ipynb_checkpoints/finetunedvqgan-checkpoint.py
ADDED
@@ -0,0 +1,31 @@
(contents identical to modules/finetunedvqgan.py, shown below)
modules/.ipynb_checkpoints/frameworkeval-checkpoint.py
ADDED
@@ -0,0 +1,56 @@
(contents identical to modules/frameworkeval.py, shown below)
modules/.ipynb_checkpoints/modelz-checkpoint.py
ADDED
@@ -0,0 +1,155 @@
(contents identical to modules/modelz.py, shown below)
modules/.ipynb_checkpoints/segmentface-checkpoint.py
ADDED
@@ -0,0 +1,75 @@
(contents identical to modules/segmentface.py, shown below)
modules/denormalize.py
ADDED
@@ -0,0 +1,21 @@
import numpy as np
import torch
import math

#------------------Denormalization---------------------------------------------
def denormalize_bin(tensor):
    tr = torch.clamp(tensor, -1., 1.)  # Clamp the values between -1 and 1
    tr = tr.add(1).div(2)              # Shift to [0, 1]
    return tr

def denormalize_tr(tensor):
    tr = torch.clamp(tensor, -1., 1.)  # Clamp the values between -1 and 1
    tr = tr.add(1).div(2).mul(255)     # Shift to [0, 1] and scale to [0, 255]
    tr = tr.byte()                     # Convert the tensor to uint8
    return tr

def denormalize_ar(tensor):
    tr = torch.clamp(tensor, -1., 1.)  # Clamp the values between -1 and 1
    tr = tr.add(1).div(2).mul(255)     # Shift to [0, 1] and scale to [0, 255]
    tr = tr.byte()                     # Convert the tensor to uint8
    arr = tr.permute(0, 2, 3, 1).cpu().detach().numpy()  # Convert to (N, H, W, C) numpy array
    return arr
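A minimal usage sketch for these helpers (illustrative, not part of the commit): it assumes a batch of images normalized to [-1, 1], as produced by the T.Normalize(mean=0.5, std=0.5) transform in app.py.

import torch
from modules.denormalize import denormalize_bin, denormalize_ar

batch = torch.rand(2, 3, 256, 256) * 2 - 1   # stand-in batch in [-1, 1]
unit = denormalize_bin(batch)                # float tensor in [0, 1], same shape
arrays = denormalize_ar(batch)               # uint8 numpy array, shape (2, 256, 256, 3)
print(unit.min().item(), unit.max().item(), arrays.shape, arrays.dtype)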
modules/finetunedvqgan.py
ADDED
@@ -0,0 +1,31 @@
import torch
from torch.utils.checkpoint import checkpoint
from taming.models.vqgan import VQModel
from omegaconf import OmegaConf
from taming.models.vqgan import GumbelVQ

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Generator:
    def __init__(self, config_path, device=device):
        self.config_path = config_path
        self.device = device

    def load_models(self):
        # Load configuration
        config = OmegaConf.load(self.config_path)
        # Extract parameters specific to GumbelVQ
        vq_params = config.model.params
        # Initialize the GumbelVQ model
        model_vaq = GumbelVQ(
            ddconfig=vq_params.ddconfig,
            lossconfig=vq_params.lossconfig,
            n_embed=vq_params.n_embed,
            embed_dim=vq_params.embed_dim,
            kl_weight=vq_params.kl_weight,
            temperature_scheduler_config=vq_params.temperature_scheduler_config,
        ).to(self.device)

        return model_vaq
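A hedged sketch of how this wrapper is consumed: the config and checkpoint paths are the ones referenced in app.py (assumptions here), and the checkpoint is assumed to be a plain GumbelVQ state dict.

import torch
from modules.finetunedvqgan import Generator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_vaq_f = Generator("./models/config.yaml", device).load_models()  # build GumbelVQ from config
state = torch.load("./models/model_vaq1_ff.pth", map_location=device)  # fine-tuned weights
model_vaq_f.load_state_dict(state, strict=True)
model_vaq_f.eval()                                                     # encode()/decode() then usable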
modules/frameworkeval.py
ADDED
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16
from torchmetrics.functional import structural_similarity_index_measure
from facenet_pytorch import InceptionResnetV1
from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar  # package path, matching app.py's imports

class DF(nn.Module):
    def __init__(self):
        super(DF, self).__init__()
        self.mse_weight = 0.25
        self.perceptual_weight = 0.25
        self.ssim_weight = 0.25
        self.idsim_weight = 0.25
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vgg = vgg16(pretrained=True).features[:16].to(self.device).eval()
        self.facenet = InceptionResnetV1(pretrained='vggface2').to(self.device).eval()
        for param in self.facenet.parameters():
            param.requires_grad = False  # Freeze the model
        self.cosloss = nn.CosineEmbeddingLoss()

    def perceptual_loss(self, real, fake):
        with torch.no_grad():  # VGG is frozen during training
            real_features = self.vgg(real)
            fake_features = self.vgg(fake)
        return F.mse_loss(real_features, fake_features)

    def idsimilarity(self, real, fake):
        with torch.no_grad():
            # Extract embeddings
            input_embed = self.facenet(real).to(self.device)
            generated_embed = self.facenet(fake).to(self.device)
        # Compute cosine similarity loss
        target = torch.ones(input_embed.size(0)).to(real.device)  # Target = 1 (maximize similarity)
        return self.cosloss(input_embed, generated_embed, target)

    def forward(self, r, f):
        real = denormalize_bin(r)  # [-1, 1] to [0, 1]
        fake = denormalize_bin(f)
        mse_loss = F.mse_loss(real, fake)
        perceptual_loss = self.perceptual_loss(real, fake)
        idsim_loss = self.idsimilarity(real, fake)
        ssim = structural_similarity_index_measure(fake, real)
        ssim_loss = 1 - ssim
        id_si = 1 - idsim_loss  # identity similarity (not used in the total below)

        total_loss = (self.mse_weight * mse_loss) + (self.perceptual_weight * perceptual_loss) \
                     + (self.idsim_weight * idsim_loss) + (self.ssim_weight * ssim_loss)
        components = {
            "MSE Loss": mse_loss.item(),
            "Perceptual Loss": perceptual_loss.item(),
            "ID-SIM Loss": idsim_loss.item(),
            "SSIM Loss": ssim_loss.item()
        }

        return total_loss, components
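A minimal, illustrative way to exercise this loss (not from the commit): both inputs are expected in [-1, 1] at 256x256, and instantiating DF downloads the VGG16 and FaceNet weights on first use.

import torch
from modules.frameworkeval import DF

criterion = DF()
real = (torch.rand(1, 3, 256, 256) * 2 - 1).to(criterion.device)  # stand-in for the deepfake input
fake = (torch.rand(1, 3, 256, 256) * 2 - 1).to(criterion.device)  # stand-in for the reconstruction
loss, parts = criterion(real, fake)
print(round(loss.item(), 3), parts)  # weighted total plus per-term breakdown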
modules/modelz.py
ADDED
@@ -0,0 +1,155 @@
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

##_____________________Define:MODEL-F & MODEL-G_________________

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1024):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(0.1)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

# Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self, d_model=256, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
                                       dropout=dropout, batch_first=True),
            num_layers=num_layers
        )

    def preprocess_latent(self, Z):
        batch_size, channels, height, width = Z.shape  # (batch_size, 256, 32, 32)
        seq_len = height * width
        Z = Z.permute(0, 2, 3, 1).reshape(batch_size, seq_len, channels)  # (batch_size, 1024, 256)
        return Z

    def postprocess_latent(self, Z):
        batch_size, seq_len, channels = Z.shape  # (batch_size, 1024, 256)
        height = width = int(math.sqrt(seq_len))
        Z = Z.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)  # (batch_size, 256, 32, 32)
        return Z

    def forward(self, Z):
        Z = self.preprocess_latent(Z)
        Z = self.positional_encoding(Z)
        Z = self.encoder(Z)
        Z = self.postprocess_latent(Z)
        return Z  # latent of transformer

class TransformerDecoder(nn.Module):
    def __init__(self, d_model=256, nhead=8, num_layers=12, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Enhanced positional encoding
        self.positional_encoding = PositionalEncoding(d_model)

        # Multi-layer learnable start tokens
        self.base_start = nn.Parameter(torch.randn(1, 1024, d_model))
        self.start_net = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.LayerNorm(d_model)
        )

        # Context-aware transformer decoder
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=num_layers
        )

        # Output projection with residual
        self.output_layer = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model))

        self.init_weights()

    def init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        nn.init.normal_(self.base_start, mean=0, std=0.02)

    def preprocess_latent(self, Z):
        # Convert (B, C, H, W) to (B, H*W, C)
        return Z.permute(0, 2, 3, 1).flatten(1, 2)

    def postprocess_latent(self, Z):
        # Convert (B, H*W, C) back to (B, C, H, W)
        B, L, C = Z.shape
        H = W = int(L ** 0.5)
        return Z.view(B, H, W, C).permute(0, 3, 1, 2)

    def forward(self, Z, Z1_start_tokens=None, teacher_forcing_ratio=0.5):
        # Process input latent
        Z = self.preprocess_latent(Z)
        # Z = self.positional_encoding(Z)

        # Generate enhanced start tokens
        B = Z.size(0)
        base_tokens = self.base_start.expand(B, -1, -1)
        processed_start = self.start_net(base_tokens)

        # Teacher forcing integration
        if Z1_start_tokens is not None and teacher_forcing_ratio > 0:
            Z1_processed = self.positional_encoding(self.preprocess_latent(Z1_start_tokens))

            # Create mixing mask
            mask = torch.rand(B, 1, 1, device=Z.device) < teacher_forcing_ratio
            processed_start = torch.where(mask, Z1_processed, processed_start)

        # Decoder processing with residual
        decoder_input = self.positional_encoding(processed_start)
        outputs = self.decoder(decoder_input, Z)
        outputs = self.output_layer(outputs + decoder_input)

        return self.postprocess_latent(outputs)

class DeepfakeToSourceTransformer(nn.Module):
    def __init__(self, d_model=256, encoder_nhead=8, decoder_nhead=8, num_encoder_layers=6,
                 num_decoder_layers=12, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.encoder = TransformerEncoder(
            d_model=d_model,
            nhead=encoder_nhead,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.decoder = TransformerDecoder(
            d_model=d_model,
            nhead=decoder_nhead,
            num_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

    def forward(self, Z, Z1_start_tokens=None, teacher_forcing_ratio=0.5):
        memory = self.encoder(Z)
        Z1 = self.decoder(memory, Z1_start_tokens, teacher_forcing_ratio=teacher_forcing_ratio)
        return Z1
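An illustrative shape check (not part of the commit): the model maps a VQGAN quantized latent of shape (B, 256, 32, 32) to a latent of the same shape, which is what app.py decodes back into an image.

import torch
from modules.modelz import DeepfakeToSourceTransformer

model = DeepfakeToSourceTransformer()
z_df = torch.randn(1, 256, 32, 32)   # stand-in for a quantized deepfake latent
with torch.no_grad():
    z_src = model(z_df)              # no start tokens, so the learned start tokens are used
print(z_src.shape)                   # torch.Size([1, 256, 32, 32])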
modules/segmentface.py
ADDED
@@ -0,0 +1,75 @@
import io
import cv2
import mediapipe as mp
import numpy as np
from rembg import remove
from PIL import Image

class FaceSegmenter:
    def __init__(self, threshold=0.5):
        self.threshold = threshold
        # Initialize face detection
        self.face_detection = mp.solutions.face_detection.FaceDetection(
            model_selection=1,  # 1 for general use, 0 for close-up faces
            min_detection_confidence=0.5
        )
        # Initialize selfie segmentation (for background removal)
        self.selfie_segmentation = mp.solutions.selfie_segmentation.SelfieSegmentation(
            model_selection=1  # 1 for general use, 0 for close-up faces
        )

    def segment_face(self, image_path):
        # Load the image
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError("Image not found or unable to load.")

        # Convert to RGB (MediaPipe requires RGB input)
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Step 1: Detect the face
        face_results = self.face_detection.process(rgb_image)
        if not face_results.detections:
            # No face found: use rembg to remove the background instead
            with open(image_path, "rb") as input_file:
                input_image = input_file.read()
            output_image = remove(input_image)
            # Convert the output image to a numpy array
            output_image = np.array(Image.open(io.BytesIO(output_image)))
            # Convert RGBA to RGB (remove alpha channel)
            if output_image.shape[2] == 4:
                output_image = cv2.cvtColor(output_image, cv2.COLOR_RGBA2RGB)
            return output_image  # note: this fallback returns an RGB array; the path below returns BGR

        # Get the bounding box of the first detected face
        detection = face_results.detections[0]
        bboxC = detection.location_data.relative_bounding_box
        h, w, _ = image.shape
        x, y, width, height = int(bboxC.xmin * w), int(bboxC.ymin * h), \
                              int(bboxC.width * w), int(bboxC.height * h)

        # Step 2: Segment the foreground (selfie segmentation)
        segmentation_results = self.selfie_segmentation.process(rgb_image)
        if segmentation_results.segmentation_mask is None:
            raise ValueError("Segmentation failed.")

        # Create a binary mask
        mask = (segmentation_results.segmentation_mask > self.threshold).astype(np.uint8)

        # Step 3: Crop the face using the bounding box
        face_mask = np.zeros_like(mask)
        face_mask[y:y+height, x:x+width] = mask[y:y+height, x:x+width]

        # Apply the mask to the original image
        segmented_face = cv2.bitwise_and(image, image, mask=face_mask)

        return segmented_face  # BGR array (OpenCV convention)

    def save_segmented_face(self, image_path, output_path):
        segmented_face = self.segment_face(image_path)
        cv2.imwrite(output_path, segmented_face)

    def show_segmented_face(self, image_path):
        segmented_face = self.segment_face(image_path)
        cv2.imshow("Segmented Face", segmented_face)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
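A short, illustrative usage sketch (not part of the commit; the example path is one of the images added above, and the output filename is hypothetical):

from modules.segmentface import FaceSegmenter

segmenter = FaceSegmenter(threshold=0.5)
seg = segmenter.segment_face("./images/df1.jpg")  # ndarray with the background zeroed out
print(seg.shape)                                  # (H, W, 3)
segmenter.save_segmented_face("./images/df1.jpg", "df1_segmented.jpg")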