allmodels

- finetunedvqgan.py +29 -0
- modelz.py +155 -0
- segmentface.py +75 -0
finetunedvqgan.py
ADDED
@@ -0,0 +1,29 @@
import torch
from torch.utils.checkpoint import checkpoint
from taming.models.vqgan import VQModel, GumbelVQ
from omegaconf import OmegaConf


class Generator:
    def __init__(self, config_path, device=None):
        self.config_path = config_path
        # Default to GPU when available; the original default referenced an
        # undefined global `device`.
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")

    def load_models(self):
        # Load configuration
        config = OmegaConf.load(self.config_path)
        # Extract parameters specific to GumbelVQ
        vq_params = config.model.params
        # Initialize the GumbelVQ model
        model_vaq = GumbelVQ(
            ddconfig=vq_params.ddconfig,
            lossconfig=vq_params.lossconfig,
            n_embed=vq_params.n_embed,
            embed_dim=vq_params.embed_dim,
            kl_weight=vq_params.kl_weight,
            temperature_scheduler_config=vq_params.temperature_scheduler_config,
        ).to(self.device)

        return model_vaq
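A minimal usage sketch, assuming a taming-transformers GumbelVQ YAML config; the config and checkpoint paths below are hypothetical placeholders, not files confirmed by this commit:

    # Hypothetical paths; point these at the Space's actual config and checkpoint.
    generator = Generator(config_path="configs/vqgan_gumbel_f8.yaml", device="cuda")
    model_vaq = generator.load_models()

    # Optionally restore fine-tuned weights before inference (assumed checkpoint layout).
    state = torch.load("checkpoints/vqgan_finetuned.ckpt", map_location=generator.device)
    model_vaq.load_state_dict(state.get("state_dict", state), strict=False)
    model_vaq.eval()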
modelz.py
ADDED
@@ -0,0 +1,155 @@
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

##_____________________Define:MODEL-F & MODEL-G_________________

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1024):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(0.1)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

# Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self, d_model=256, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True),
            num_layers=num_layers
        )

    def preprocess_latent(self, Z):
        batch_size, channels, height, width = Z.shape  # (batch_size, 256, 32, 32)
        seq_len = height * width
        Z = Z.permute(0, 2, 3, 1).reshape(batch_size, seq_len, channels)  # (batch_size, 1024, 256)
        return Z

    def postprocess_latent(self, Z):
        batch_size, seq_len, channels = Z.shape  # (batch_size, 1024, 256)
        height = width = int(math.sqrt(seq_len))
        Z = Z.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)  # (batch_size, 256, 32, 32)
        return Z

    def forward(self, Z):
        Z = self.preprocess_latent(Z)
        Z = self.positional_encoding(Z)
        Z = self.encoder(Z)
        Z = self.postprocess_latent(Z)
        return Z  # latent of transformer

class TransformerDecoder(nn.Module):
    def __init__(self, d_model=256, nhead=8, num_layers=12, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Enhanced positional encoding
        self.positional_encoding = PositionalEncoding(d_model)

        # Multi-layer learnable start tokens
        self.base_start = nn.Parameter(torch.randn(1, 1024, d_model))
        self.start_net = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.LayerNorm(d_model)
        )

        # Context-aware transformer decoder
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=num_layers
        )

        # Output projection with residual
        self.output_layer = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model)
        )

        self.init_weights()

    def init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        nn.init.normal_(self.base_start, mean=0, std=0.02)

    def preprocess_latent(self, Z):
        # Convert (B, C, H, W) to (B, H*W, C)
        return Z.permute(0, 2, 3, 1).flatten(1, 2)

    def postprocess_latent(self, Z):
        # Convert (B, H*W, C) back to (B, C, H, W)
        B, L, C = Z.shape
        H = W = int(L**0.5)
        return Z.view(B, H, W, C).permute(0, 3, 1, 2)

    def forward(self, Z, Z1_start_tokens=None, teacher_forcing_ratio=0.5):
        # Process input latent
        Z = self.preprocess_latent(Z)
        # Z = self.positional_encoding(Z)

        # Generate enhanced start tokens
        B = Z.size(0)
        base_tokens = self.base_start.expand(B, -1, -1)
        processed_start = self.start_net(base_tokens)

        # Teacher forcing integration
        if Z1_start_tokens is not None and teacher_forcing_ratio > 0:
            Z1_processed = self.positional_encoding(self.preprocess_latent(Z1_start_tokens))

            # Create mixing mask
            mask = torch.rand(B, 1, 1, device=Z.device) < teacher_forcing_ratio
            processed_start = torch.where(mask, Z1_processed, processed_start)

        # Decoder processing with residual
        decoder_input = self.positional_encoding(processed_start)
        outputs = self.decoder(decoder_input, Z)
        outputs = self.output_layer(outputs + decoder_input)

        return self.postprocess_latent(outputs)

class DeepfakeToSourceTransformer(nn.Module):
    def __init__(self, d_model=256, encoder_nhead=8, decoder_nhead=8, num_encoder_layers=6, num_decoder_layers=12, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.encoder = TransformerEncoder(
            d_model=d_model,
            nhead=encoder_nhead,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.decoder = TransformerDecoder(
            d_model=d_model,
            nhead=decoder_nhead,
            num_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

    def forward(self, Z, Z1_start_tokens=None, teacher_forcing_ratio=0.5):
        memory = self.encoder(Z)
        Z1 = self.decoder(memory, Z1_start_tokens, teacher_forcing_ratio=teacher_forcing_ratio)
        return Z1
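A shape sanity check for the encoder-decoder pair, assuming the 256-channel 32x32 VQGAN latents noted in the comments above; it uses random tensors and no trained weights:

    model = DeepfakeToSourceTransformer()
    Z_fake = torch.randn(2, 256, 32, 32)   # deepfake latent (B, C, H, W)
    Z_src = torch.randn(2, 256, 32, 32)    # source latent, used only for teacher forcing
    Z_pred = model(Z_fake, Z1_start_tokens=Z_src, teacher_forcing_ratio=0.5)
    print(Z_pred.shape)                    # torch.Size([2, 256, 32, 32])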
segmentface.py
ADDED
@@ -0,0 +1,75 @@
import io
import cv2
import mediapipe as mp
import numpy as np
from rembg import remove
from PIL import Image

class FaceSegmenter:
    def __init__(self, threshold=0.5):
        self.threshold = threshold
        # Initialize face detection
        self.face_detection = mp.solutions.face_detection.FaceDetection(
            model_selection=1,  # 1 = full-range model, 0 = short-range (close-up) model
            min_detection_confidence=0.5
        )
        # Initialize selfie segmentation (for background removal)
        self.selfie_segmentation = mp.solutions.selfie_segmentation.SelfieSegmentation(
            model_selection=1  # 1 = landscape model, 0 = general model
        )

    def segment_face(self, image_path):
        # Load the image
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError("Image not found or unable to load.")

        # Convert to RGB (MediaPipe requires RGB input)
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Step 1: Detect the face
        face_results = self.face_detection.process(rgb_image)
        if not face_results.detections:
            # No face found: fall back to rembg background removal
            with open(image_path, "rb") as input_file:
                input_image = input_file.read()
            output_image = remove(input_image)
            # Convert the output bytes to a numpy array
            output_image = np.array(Image.open(io.BytesIO(output_image)))
            # Convert RGBA to RGB (remove alpha channel)
            if output_image.shape[2] == 4:
                output_image = cv2.cvtColor(output_image, cv2.COLOR_RGBA2RGB)
            # Convert to BGR so the return convention matches the cv2 path below
            output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
            return output_image

        # Get the bounding box of the first detected face
        detection = face_results.detections[0]
        bboxC = detection.location_data.relative_bounding_box
        h, w, _ = image.shape
        x, y, width, height = int(bboxC.xmin * w), int(bboxC.ymin * h), \
                              int(bboxC.width * w), int(bboxC.height * h)

        # Step 2: Segment the foreground (selfie segmentation)
        segmentation_results = self.selfie_segmentation.process(rgb_image)
        if segmentation_results.segmentation_mask is None:
            raise ValueError("Segmentation failed.")

        # Create a binary mask
        mask = (segmentation_results.segmentation_mask > self.threshold).astype(np.uint8)

        # Step 3: Crop the face using the bounding box
        face_mask = np.zeros_like(mask)
        face_mask[y:y+height, x:x+width] = mask[y:y+height, x:x+width]

        # Apply the mask to the original image
        segmented_face = cv2.bitwise_and(image, image, mask=face_mask)

        return segmented_face

    def save_segmented_face(self, image_path, output_path):
        segmented_face = self.segment_face(image_path)
        cv2.imwrite(output_path, segmented_face)

    def show_segmented_face(self, image_path):
        segmented_face = self.segment_face(image_path)
        cv2.imshow("Segmented Face", segmented_face)
        cv2.waitKey(0)
        cv2.destroyAllWindows()