Jannat24 committed
Commit b3cd85c · 1 Parent(s): bbb5f33
Files changed (3)
  1. finetunedvqgan.py +29 -0
  2. modelz.py +155 -0
  3. segmentface.py +75 -0
finetunedvqgan.py ADDED
@@ -0,0 +1,29 @@
+import torch
+from torch.utils.checkpoint import checkpoint
+from taming.models.vqgan import VQModel
+from omegaconf import OmegaConf
+from taming.models.vqgan import GumbelVQ
+
+class Generator:
+    def __init__(self, config_path, device=None):
+        self.config_path = config_path
+        # Default to GPU when available
+        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
+
+    def load_models(self):
+        # Load configuration
+        config = OmegaConf.load(self.config_path)
+        # Extract parameters specific to GumbelVQ
+        vq_params = config.model.params
+        # Initialize the GumbelVQ model
+        model_vaq = GumbelVQ(
+            ddconfig=vq_params.ddconfig,
+            lossconfig=vq_params.lossconfig,
+            n_embed=vq_params.n_embed,
+            embed_dim=vq_params.embed_dim,
+            kl_weight=vq_params.kl_weight,
+            temperature_scheduler_config=vq_params.temperature_scheduler_config,
+        ).to(self.device)
+
+        return model_vaq
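For reference, a minimal usage sketch of the Generator class above; the config path below is a hypothetical placeholder for whichever GumbelVQ YAML this repo uses, and weights would still need to be loaded from a fine-tuned checkpoint:

import torch
from finetunedvqgan import Generator

# "configs/gumbel_vqgan.yaml" is a placeholder path, not a file shipped in this commit.
generator = Generator(config_path="configs/gumbel_vqgan.yaml", device="cpu")
model_vaq = generator.load_models()
model_vaq.eval()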
modelz.py ADDED
@@ -0,0 +1,155 @@
+import numpy as np
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+##_____________________Define: MODEL-F & MODEL-G_________________
+
+# Positional Encoding
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=1024):
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(0.1)
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe = torch.zeros(max_len, d_model)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)
+
+    def forward(self, x):
+        x = x + self.pe[:, :x.size(1)]
+        return self.dropout(x)
+
+# Transformer Encoder
+class TransformerEncoder(nn.Module):
+    def __init__(self, d_model=256, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.1):
+        super(TransformerEncoder, self).__init__()
+        self.positional_encoding = PositionalEncoding(d_model)
+        self.encoder = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True),
+            num_layers=num_layers
+        )
+
+    def preprocess_latent(self, Z):
+        batch_size, channels, height, width = Z.shape  # (batch_size, 256, 32, 32)
+        seq_len = height * width
+        Z = Z.permute(0, 2, 3, 1).reshape(batch_size, seq_len, channels)  # (batch_size, 1024, 256)
+        return Z
+
+    def postprocess_latent(self, Z):
+        batch_size, seq_len, channels = Z.shape  # (batch_size, 1024, 256)
+        height = width = int(math.sqrt(seq_len))
+        Z = Z.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)  # (batch_size, 256, 32, 32)
+        return Z
+
+    def forward(self, Z):
+        Z = self.preprocess_latent(Z)
+        Z = self.positional_encoding(Z)
+        Z = self.encoder(Z)
+        Z = self.postprocess_latent(Z)
+        return Z  # latent of transformer
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, d_model=256, nhead=8, num_layers=12, dim_feedforward=1024, dropout=0.1):
+        super().__init__()
+        self.d_model = d_model
+
+        # Enhanced positional encoding
+        self.positional_encoding = PositionalEncoding(d_model)
+
+        # Multi-layer learnable start tokens
+        self.base_start = nn.Parameter(torch.randn(1, 1024, d_model))
+        self.start_net = nn.Sequential(
+            nn.LayerNorm(d_model),
+            nn.Linear(d_model, dim_feedforward),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(dim_feedforward, d_model),
+            nn.LayerNorm(d_model)
+        )
+
+        # Context-aware transformer decoder
+        self.decoder = nn.TransformerDecoder(
+            nn.TransformerDecoderLayer(
+                d_model=d_model,
+                nhead=nhead,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+                batch_first=True
+            ),
+            num_layers=num_layers
+        )
+
+        # Output projection with residual
+        self.output_layer = nn.Sequential(
+            nn.Linear(d_model, d_model * 2),
+            nn.GELU(),
+            nn.Linear(d_model * 2, d_model))
+
+        self.init_weights()
+
+    def init_weights(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+        nn.init.normal_(self.base_start, mean=0, std=0.02)
+
+    def preprocess_latent(self, Z):
+        # Convert (B, C, H, W) to (B, H*W, C)
+        return Z.permute(0, 2, 3, 1).flatten(1, 2)
+
+    def postprocess_latent(self, Z):
+        # Convert (B, H*W, C) back to (B, C, H, W)
+        B, L, C = Z.shape
+        H = W = int(L**0.5)
+        return Z.view(B, H, W, C).permute(0, 3, 1, 2)
+
+    def forward(self, Z, Z1_start_tokens=None, teacher_forcing_ratio=0.5):
+        # Process input latent
+        Z = self.preprocess_latent(Z)
+        #Z = self.positional_encoding(Z)
+
+        # Generate enhanced start tokens
+        B = Z.size(0)
+        base_tokens = self.base_start.expand(B, -1, -1)
+        processed_start = self.start_net(base_tokens)
+
+        # Teacher forcing integration
+        if Z1_start_tokens is not None and teacher_forcing_ratio > 0:
+            Z1_processed = self.positional_encoding(self.preprocess_latent(Z1_start_tokens))
+
+            # Create mixing mask
+            mask = torch.rand(B, 1, 1, device=Z.device) < teacher_forcing_ratio
+            processed_start = torch.where(mask, Z1_processed, processed_start)
+
+        # Decoder processing with residual
+        decoder_input = self.positional_encoding(processed_start)
+        outputs = self.decoder(decoder_input, Z)
+        outputs = self.output_layer(outputs + decoder_input)
+
+        return self.postprocess_latent(outputs)
+
+class DeepfakeToSourceTransformer(nn.Module):
+    def __init__(self, d_model=256, encoder_nhead=8, decoder_nhead=8, num_encoder_layers=6, num_decoder_layers=12, dim_feedforward=1024, dropout=0.1):
+        super().__init__()
+        self.encoder = TransformerEncoder(
+            d_model=d_model,
+            nhead=encoder_nhead,
+            num_layers=num_encoder_layers,
+            dim_feedforward=1024,
+            dropout=dropout
+        )
+        self.decoder = TransformerDecoder(
+            d_model=d_model,
+            nhead=decoder_nhead,
+            num_layers=num_decoder_layers,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout
+        )
+
+    def forward(self, Z, Z1_start_tokens=None, teacher_forcing_ratio=0.5):
+        memory = self.encoder(Z)
+        Z1 = self.decoder(memory, Z1_start_tokens, teacher_forcing_ratio=teacher_forcing_ratio)
+        return Z1
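For context, a minimal shape-check sketch for DeepfakeToSourceTransformer, assuming the 256-channel 32x32 latents referenced in the comments above:

import torch
from modelz import DeepfakeToSourceTransformer

model = DeepfakeToSourceTransformer()
Z = torch.randn(2, 256, 32, 32)       # deepfake latent (B, C, H, W)
Z1_ref = torch.randn(2, 256, 32, 32)  # optional start latent for teacher forcing
Z1 = model(Z, Z1_start_tokens=Z1_ref, teacher_forcing_ratio=0.5)
print(Z1.shape)  # torch.Size([2, 256, 32, 32])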
segmentface.py ADDED
@@ -0,0 +1,75 @@
+import io
+
+import cv2
+import mediapipe as mp
+import numpy as np
+from rembg import remove
+from PIL import Image
+
+class FaceSegmenter:
+    def __init__(self, threshold=0.5):
+        self.threshold = threshold
+        # Initialize face detection
+        self.face_detection = mp.solutions.face_detection.FaceDetection(
+            model_selection=1,  # 1 for general use, 0 for close-up faces
+            min_detection_confidence=0.5
+        )
+        # Initialize selfie segmentation (for background removal)
+        self.selfie_segmentation = mp.solutions.selfie_segmentation.SelfieSegmentation(
+            model_selection=1  # 1 for general use, 0 for close-up faces
+        )
+
+    def segment_face(self, image_path):
+        # Load the image
+        image = cv2.imread(image_path)
+        if image is None:
+            raise ValueError("Image not found or unable to load.")
+
+        # Convert to RGB (MediaPipe requires RGB input)
+        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        # Step 1: Detect the face
+        face_results = self.face_detection.process(rgb_image)
+        if not face_results.detections:
+            # No face detected: fall back to rembg background removal
+            with open(image_path, "rb") as input_file:
+                input_image = input_file.read()
+            output_image = remove(input_image)
+            # Convert the output bytes to a numpy array
+            output_image = np.array(Image.open(io.BytesIO(output_image)))
+            # Convert RGBA to RGB (remove alpha channel)
+            if output_image.shape[2] == 4:
+                output_image = cv2.cvtColor(output_image, cv2.COLOR_RGBA2RGB)
+            return output_image
+
+        # Get the bounding box of the first detected face
+        detection = face_results.detections[0]
+        bboxC = detection.location_data.relative_bounding_box
+        h, w, _ = image.shape
+        x, y, width, height = int(bboxC.xmin * w), int(bboxC.ymin * h), \
+            int(bboxC.width * w), int(bboxC.height * h)
+
+        # Step 2: Segment the foreground (selfie segmentation)
+        segmentation_results = self.selfie_segmentation.process(rgb_image)
+        if segmentation_results.segmentation_mask is None:
+            raise ValueError("Segmentation failed.")
+
+        # Create a binary mask
+        mask = (segmentation_results.segmentation_mask > self.threshold).astype(np.uint8)
+
+        # Step 3: Crop the face using the bounding box
+        face_mask = np.zeros_like(mask)
+        face_mask[y:y+height, x:x+width] = mask[y:y+height, x:x+width]
+
+        # Apply the mask to the original image
+        segmented_face = cv2.bitwise_and(image, image, mask=face_mask)
+
+        return segmented_face
+
+    def save_segmented_face(self, image_path, output_path):
+        segmented_face = self.segment_face(image_path)
+        cv2.imwrite(output_path, segmented_face)
+
+    def show_segmented_face(self, image_path):
+        segmented_face = self.segment_face(image_path)
+        cv2.imshow("Segmented Face", segmented_face)
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
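A minimal usage sketch for FaceSegmenter; the file names below are hypothetical placeholders:

from segmentface import FaceSegmenter

segmenter = FaceSegmenter(threshold=0.5)
# Writes the masked face crop; falls back to rembg background removal when no face is detected.
segmenter.save_segmented_face("input_face.jpg", "segmented_face.png")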