WebashalarForML committed
Commit 0899cd0 · verified · 1 Parent(s): 3bc7327

Update inference2.py

Files changed (1)
  1. inference2.py +345 -345
inference2.py CHANGED
@@ -1,346 +1,346 @@
- # inference.py (Updated)
-
- from os import listdir, path
- import numpy as np
- import scipy, cv2, os, sys, argparse, audio
- import json, subprocess, random, string
- from tqdm import tqdm
- from glob import glob
- import torch # Ensure torch is imported
- try:
-     import face_detection # Assuming this is installed or in a path accessible by your Flask app
- except ImportError:
-     print("face_detection not found. Please ensure it's installed or available in your PYTHONPATH.")
-     # You might want to raise an error or handle this gracefully if face_detection is truly optional.
-
- # Make sure you have a models/Wav2Lip.py or similar structure
- try:
-     from models import Wav2Lip
- except ImportError:
-     print("Wav2Lip model not found. Please ensure models/Wav2Lip.py exists and is correctly configured.")
-     # You might want to raise an error or handle this gracefully.
-
- import platform
- import shutil # For clearing temp directory
-
-
- # These globals are still useful for shared configuration
- mel_step_size = 16
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- print('Inference script using {} for inference.'.format(device))
-
-
- def get_smoothened_boxes(boxes, T):
-     for i in range(len(boxes)):
-         if i + T > len(boxes):
-             window = boxes[len(boxes) - T:]
-         else:
-             window = boxes[i : i + T]
-         boxes[i] = np.mean(window, axis=0)
-     return boxes
-
- def face_detect(images, pads, face_det_batch_size, nosmooth, img_size):
-     detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
-                                             flip_input=False, device=device)
-
-     batch_size = face_det_batch_size
-
-     while 1:
-         predictions = []
-         try:
-             for i in tqdm(range(0, len(images), batch_size), desc="Face Detection"):
-                 predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
-         except RuntimeError as e:
-             if batch_size == 1:
-                 raise RuntimeError(f'Image too big to run face detection on GPU. Error: {e}')
-             batch_size //= 2
-             print('Recovering from OOM error; New face detection batch size: {}'.format(batch_size))
-             continue
-         break
-
-     results = []
-     pady1, pady2, padx1, padx2 = pads
-     for rect, image in zip(predictions, images):
-         if rect is None:
-             # Save the faulty frame for debugging
-             output_dir = 'temp' # Ensure this exists or create it
-             os.makedirs(output_dir, exist_ok=True)
-             cv2.imwrite(os.path.join(output_dir, 'faulty_frame.jpg'), image)
-             raise ValueError('Face not detected! Ensure the video/image contains a face in all the frames or try adjusting pads/box.')
-
-         y1 = max(0, rect[1] - pady1)
-         y2 = min(image.shape[0], rect[3] + pady2)
-         x1 = max(0, rect[0] - padx1)
-         x2 = min(image.shape[1], rect[2] + padx2)
-
-         results.append([x1, y1, x2, y2])
-
-     boxes = np.array(results)
-     if not nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
-     results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
-
-     del detector # Clean up detector
-     return results
-
- def datagen(frames, mels, box, static, wav2lip_batch_size, img_size, pads, face_det_batch_size, nosmooth):
-     img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
-
-     if box[0] == -1:
-         if not static:
-             face_det_results = face_detect(frames, pads, face_det_batch_size, nosmooth, img_size) # BGR2RGB for CNN face detection
-         else:
-             face_det_results = face_detect([frames[0]], pads, face_det_batch_size, nosmooth, img_size)
-     else:
-         print('Using the specified bounding box instead of face detection...')
-         y1, y2, x1, x2 = box
-         face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
-
-     for i, m in enumerate(mels):
-         idx = 0 if static else i % len(frames)
-         frame_to_save = frames[idx].copy()
-         face, coords = face_det_results[idx].copy()
-
-         face = cv2.resize(face, (img_size, img_size))
-
-         img_batch.append(face)
-         mel_batch.append(m)
-         frame_batch.append(frame_to_save)
-         coords_batch.append(coords)
-
-         if len(img_batch) >= wav2lip_batch_size:
-             img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
-
-             img_masked = img_batch.copy()
-             img_masked[:, img_size//2:] = 0
-
-             img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
-             mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
-
-             yield img_batch, mel_batch, frame_batch, coords_batch
-             img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
-
-     if len(img_batch) > 0:
-         img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
-
-         img_masked = img_batch.copy()
-         img_masked[:, img_size//2:] = 0
-
-         img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
-         mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
-
-         yield img_batch, mel_batch, frame_batch, coords_batch
-
- def _load(checkpoint_path):
-     # Use torch.jit.load for TorchScript archives
-     if device == 'cuda':
-         model = torch.jit.load(checkpoint_path)
-     else:
-         # Accepts string or torch.device, not a lambda
-         model = torch.jit.load(checkpoint_path, map_location='cpu')
-     return model
-
- def load_model(path):
-     print("Loading scripted model from:", path)
-     model = _load(path) # returns the TorchScript Module
-     model = model.to(device) # move to CPU or GPU
-     return model.eval() # set to eval() mode
-
-
- # New function to be called from Flask app
- def run_inference(
-     checkpoint_path: str,
-     face_path: str,
-     audio_path: str,
-     output_filename: str,
-     static: bool = False,
-     fps: float = 25.,
-     pads: list = [0, 10, 0, 0],
-     face_det_batch_size: int = 16,
-     wav2lip_batch_size: int = 128,
-     resize_factor: int = 1,
-     crop: list = [0, -1, 0, -1],
-     box: list = [-1, -1, -1, -1],
-     rotate: bool = False,
-     nosmooth: bool = False,
-     img_size: int = 96 # Fixed for Wav2Lip
- ) -> str:
-     """
-     Runs the Wav2Lip inference process.
-
-     Args:
-         checkpoint_path (str): Path to the Wav2Lip model checkpoint.
-         face_path (str): Path to the input video/image file with a face.
-         audio_path (str): Path to the input audio file.
-         output_filename (str): Name of the output video file (e.g., 'result.mp4').
-         static (bool): If True, use only the first video frame for inference.
-         fps (float): Frames per second for static image input.
-         pads (list): Padding for face detection (top, bottom, left, right).
-         face_det_batch_size (int): Batch size for face detection.
-         wav2lip_batch_size (int): Batch size for Wav2Lip model(s).
-         resize_factor (int): Reduce the resolution by this factor.
-         crop (list): Crop video to a smaller region (top, bottom, left, right).
-         box (list): Constant bounding box for the face.
-         rotate (bool): Rotate video right by 90deg.
-         nosmooth (bool): Prevent smoothing face detections.
-         img_size (int): Image size for the model.
-
-     Returns:
-         str: The path to the generated output video file.
-     """
-     print(f"Starting inference with: face='{face_path}', audio='{audio_path}', checkpoint='{checkpoint_path}', outfile='{output_filename}'")
-
-     # Create necessary directories
-     output_dir = 'results'
-     temp_dir = 'temp'
-     os.makedirs(output_dir, exist_ok=True)
-     os.makedirs(temp_dir, exist_ok=True)
-
-     # Clear temp directory for fresh run
-     for item in os.listdir(temp_dir):
-         item_path = os.path.join(temp_dir, item)
-         if os.path.isfile(item_path):
-             os.remove(item_path)
-         elif os.path.isdir(item_path):
-             shutil.rmtree(item_path)
-
-     # Determine if input is static based on file extension
-     is_static_input = static or (os.path.isfile(face_path) and face_path.split('.')[-1].lower() in ['jpg', 'png', 'jpeg'])
-
-     full_frames = []
-     if is_static_input:
-         full_frames = [cv2.imread(face_path)]
-         if full_frames[0] is None:
-             raise ValueError(f"Could not read face image at: {face_path}")
-     else:
-         video_stream = cv2.VideoCapture(face_path)
-         if not video_stream.isOpened():
-             raise ValueError(f"Could not open video file at: {face_path}")
-         fps = video_stream.get(cv2.CAP_PROP_FPS)
-
-         print('Reading video frames...')
-         while 1:
-             still_reading, frame = video_stream.read()
-             if not still_reading:
-                 video_stream.release()
-                 break
-             if resize_factor > 1:
-                 frame = cv2.resize(frame, (frame.shape[1]//resize_factor, frame.shape[0]//resize_factor))
-
-             if rotate:
-                 frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
-
-             y1, y2, x1, x2 = crop
-             if x2 == -1: x2 = frame.shape[1]
-             if y2 == -1: y2 = frame.shape[0]
-
-             frame = frame[y1:y2, x1:x2]
-             full_frames.append(frame)
-
-     print ("Number of frames available for inference: "+str(len(full_frames)))
-     if not full_frames:
-         raise ValueError("No frames could be read from the input face file.")
-
-     temp_audio_path = os.path.join(temp_dir, 'temp_audio.wav')
-     if not audio_path.endswith('.wav'):
-         print('Extracting raw audio...')
-         command = f'ffmpeg -y -i "{audio_path}" -strict -2 "{temp_audio_path}"'
-         try:
-             subprocess.run(command, shell=True, check=True, capture_output=True)
-             audio_path = temp_audio_path
-         except subprocess.CalledProcessError as e:
-             print(f"FFmpeg error: {e.stderr.decode()}")
-             raise RuntimeError(f"Failed to extract audio from {audio_path}. Error: {e.stderr.decode()}")
-     else:
-         # Copy the wav file to temp if it's already wav to maintain consistency in naming
-         shutil.copy(audio_path, temp_audio_path)
-         audio_path = temp_audio_path
-
-
-     wav = audio.load_wav(audio_path, 16000)
-     mel = audio.melspectrogram(wav)
-     print("Mel spectrogram shape:", mel.shape)
-
-     if np.isnan(mel.reshape(-1)).sum() > 0:
-         raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
-
-     mel_chunks = []
-     mel_idx_multiplier = 80./fps
-     i = 0
-     while 1:
-         start_idx = int(i * mel_idx_multiplier)
-         if start_idx + mel_step_size > len(mel[0]):
-             mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
-             break
-         mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
-         i += 1
-
-     print("Length of mel chunks: {}".format(len(mel_chunks)))
-
-     # Ensure full_frames matches mel_chunks length, or loop if static
-     if not is_static_input:
-         full_frames = full_frames[:len(mel_chunks)]
-     else:
-         # If static, replicate the first frame for the duration of the audio
-         full_frames = [full_frames[0]] * len(mel_chunks)
-
-
-     gen = datagen(full_frames.copy(), mel_chunks, box, is_static_input, wav2lip_batch_size, img_size, pads, face_det_batch_size, nosmooth)
-
-     output_avi_path = os.path.join(temp_dir, 'result.avi')
-
-     model_loaded = False
-     model = None
-     frame_h, frame_w = 0, 0
-     out = None
-
-     for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, desc="Wav2Lip Inference",
-                                                                     total=int(np.ceil(float(len(mel_chunks))/wav2lip_batch_size)))):
-         if not model_loaded:
-             model = load_model(checkpoint_path)
-             model_loaded = True
-             print ("Model loaded successfully")
-
-             frame_h, frame_w = full_frames[0].shape[:-1]
-             out = cv2.VideoWriter(output_avi_path,
-                                   cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
-             if out is None: # In case no frames were generated for some reason
-                 raise RuntimeError("Video writer could not be initialized.")
-
-
-         img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
-         mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
-
-         with torch.no_grad():
-             pred = model(mel_batch, img_batch)
-
-         pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
-
-         for p, f, c in zip(pred, frames, coords):
-             y1, y2, x1, x2 = c
-             p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
-
-             f[y1:y2, x1:x2] = p
-             out.write(f)
-
-     if out:
-         out.release()
-     else:
-         print("Warning: Video writer was not initialized or no frames were processed.")
-
-
-     final_output_path = os.path.join(output_dir, output_filename)
-     command = f'ffmpeg -y -i "{audio_path}" -i "{output_avi_path}" -strict -2 -q:v 1 "{final_output_path}"'
-
-     try:
-         subprocess.run(command, shell=True, check=True, capture_output=True)
-         print(f"Output saved to: {final_output_path}")
-     except subprocess.CalledProcessError as e:
-         print(f"FFmpeg final merge error: {e.stderr.decode()}")
-         raise RuntimeError(f"Failed to merge audio and video. Error: {e.stderr.decode()}")
-
-     # Clean up temporary files (optional, but good practice)
-     # shutil.rmtree(temp_dir) # Be careful with this if you want to inspect temp files
-
-     return final_output_path
-
  # No `if __name__ == '__main__':` block here, as it's meant to be imported
 
+ # inference.py (Updated)
+ import audio
+ from os import listdir, path
+ import numpy as np
+ import scipy, cv2, os, sys, argparse, audio
+ import json, subprocess, random, string
+ from tqdm import tqdm
+ from glob import glob
+ import torch # Ensure torch is imported
+ try:
+     import face_detection # Assuming this is installed or in a path accessible by your Flask app
+ except ImportError:
+     print("face_detection not found. Please ensure it's installed or available in your PYTHONPATH.")
+     # You might want to raise an error or handle this gracefully if face_detection is truly optional.
+
+ # Make sure you have a models/Wav2Lip.py or similar structure
+ try:
+     from models import Wav2Lip
+ except ImportError:
+     print("Wav2Lip model not found. Please ensure models/Wav2Lip.py exists and is correctly configured.")
+     # You might want to raise an error or handle this gracefully.
+
+ import platform
+ import shutil # For clearing temp directory
+
+
+ # These globals are still useful for shared configuration
+ mel_step_size = 16
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print('Inference script using {} for inference.'.format(device))
+
+
+ def get_smoothened_boxes(boxes, T):
+     for i in range(len(boxes)):
+         if i + T > len(boxes):
+             window = boxes[len(boxes) - T:]
+         else:
+             window = boxes[i : i + T]
+         boxes[i] = np.mean(window, axis=0)
+     return boxes
+
+ def face_detect(images, pads, face_det_batch_size, nosmooth, img_size):
+     detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+                                             flip_input=False, device=device)
+
+     batch_size = face_det_batch_size
+
+     while 1:
+         predictions = []
+         try:
+             for i in tqdm(range(0, len(images), batch_size), desc="Face Detection"):
+                 predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+         except RuntimeError as e:
+             if batch_size == 1:
+                 raise RuntimeError(f'Image too big to run face detection on GPU. Error: {e}')
+             batch_size //= 2
+             print('Recovering from OOM error; New face detection batch size: {}'.format(batch_size))
+             continue
+         break
+
+     results = []
+     pady1, pady2, padx1, padx2 = pads
+     for rect, image in zip(predictions, images):
+         if rect is None:
+             # Save the faulty frame for debugging
+             output_dir = 'temp' # Ensure this exists or create it
+             os.makedirs(output_dir, exist_ok=True)
+             cv2.imwrite(os.path.join(output_dir, 'faulty_frame.jpg'), image)
+             raise ValueError('Face not detected! Ensure the video/image contains a face in all the frames or try adjusting pads/box.')
+
+         y1 = max(0, rect[1] - pady1)
+         y2 = min(image.shape[0], rect[3] + pady2)
+         x1 = max(0, rect[0] - padx1)
+         x2 = min(image.shape[1], rect[2] + padx2)
+
+         results.append([x1, y1, x2, y2])
+
+     boxes = np.array(results)
+     if not nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
+     results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+     del detector # Clean up detector
+     return results
+
+ def datagen(frames, mels, box, static, wav2lip_batch_size, img_size, pads, face_det_batch_size, nosmooth):
+     img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+     if box[0] == -1:
+         if not static:
+             face_det_results = face_detect(frames, pads, face_det_batch_size, nosmooth, img_size) # BGR2RGB for CNN face detection
+         else:
+             face_det_results = face_detect([frames[0]], pads, face_det_batch_size, nosmooth, img_size)
+     else:
+         print('Using the specified bounding box instead of face detection...')
+         y1, y2, x1, x2 = box
+         face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
+
+     for i, m in enumerate(mels):
+         idx = 0 if static else i % len(frames)
+         frame_to_save = frames[idx].copy()
+         face, coords = face_det_results[idx].copy()
+
+         face = cv2.resize(face, (img_size, img_size))
+
+         img_batch.append(face)
+         mel_batch.append(m)
+         frame_batch.append(frame_to_save)
+         coords_batch.append(coords)
+
+         if len(img_batch) >= wav2lip_batch_size:
+             img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+             img_masked = img_batch.copy()
+             img_masked[:, img_size//2:] = 0
+
+             img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+             mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+             yield img_batch, mel_batch, frame_batch, coords_batch
+             img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+     if len(img_batch) > 0:
+         img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+         img_masked = img_batch.copy()
+         img_masked[:, img_size//2:] = 0
+
+         img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+         mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+         yield img_batch, mel_batch, frame_batch, coords_batch
+
+ def _load(checkpoint_path):
+     # Use torch.jit.load for TorchScript archives
+     if device == 'cuda':
+         model = torch.jit.load(checkpoint_path)
+     else:
+         # Accepts string or torch.device, not a lambda
+         model = torch.jit.load(checkpoint_path, map_location='cpu')
+     return model
+
+ def load_model(path):
+     print("Loading scripted model from:", path)
+     model = _load(path) # returns the TorchScript Module
+     model = model.to(device) # move to CPU or GPU
+     return model.eval() # set to eval() mode
+
+
+ # New function to be called from Flask app
+ def run_inference(
+     checkpoint_path: str,
+     face_path: str,
+     audio_path: str,
+     output_filename: str,
+     static: bool = False,
+     fps: float = 25.,
+     pads: list = [0, 10, 0, 0],
+     face_det_batch_size: int = 16,
+     wav2lip_batch_size: int = 128,
+     resize_factor: int = 1,
+     crop: list = [0, -1, 0, -1],
+     box: list = [-1, -1, -1, -1],
+     rotate: bool = False,
+     nosmooth: bool = False,
+     img_size: int = 96 # Fixed for Wav2Lip
+ ) -> str:
+     """
+     Runs the Wav2Lip inference process.
+
+     Args:
+         checkpoint_path (str): Path to the Wav2Lip model checkpoint.
+         face_path (str): Path to the input video/image file with a face.
+         audio_path (str): Path to the input audio file.
+         output_filename (str): Name of the output video file (e.g., 'result.mp4').
+         static (bool): If True, use only the first video frame for inference.
+         fps (float): Frames per second for static image input.
+         pads (list): Padding for face detection (top, bottom, left, right).
+         face_det_batch_size (int): Batch size for face detection.
+         wav2lip_batch_size (int): Batch size for Wav2Lip model(s).
+         resize_factor (int): Reduce the resolution by this factor.
+         crop (list): Crop video to a smaller region (top, bottom, left, right).
+         box (list): Constant bounding box for the face.
+         rotate (bool): Rotate video right by 90deg.
+         nosmooth (bool): Prevent smoothing face detections.
+         img_size (int): Image size for the model.
+
+     Returns:
+         str: The path to the generated output video file.
+     """
+     print(f"Starting inference with: face='{face_path}', audio='{audio_path}', checkpoint='{checkpoint_path}', outfile='{output_filename}'")
+
+     # Create necessary directories
+     output_dir = 'results'
+     temp_dir = 'temp'
+     os.makedirs(output_dir, exist_ok=True)
+     os.makedirs(temp_dir, exist_ok=True)
+
+     # Clear temp directory for fresh run
+     for item in os.listdir(temp_dir):
+         item_path = os.path.join(temp_dir, item)
+         if os.path.isfile(item_path):
+             os.remove(item_path)
+         elif os.path.isdir(item_path):
+             shutil.rmtree(item_path)
+
+     # Determine if input is static based on file extension
+     is_static_input = static or (os.path.isfile(face_path) and face_path.split('.')[-1].lower() in ['jpg', 'png', 'jpeg'])
+
+     full_frames = []
+     if is_static_input:
+         full_frames = [cv2.imread(face_path)]
+         if full_frames[0] is None:
+             raise ValueError(f"Could not read face image at: {face_path}")
+     else:
+         video_stream = cv2.VideoCapture(face_path)
+         if not video_stream.isOpened():
+             raise ValueError(f"Could not open video file at: {face_path}")
+         fps = video_stream.get(cv2.CAP_PROP_FPS)
+
+         print('Reading video frames...')
+         while 1:
+             still_reading, frame = video_stream.read()
+             if not still_reading:
+                 video_stream.release()
+                 break
+             if resize_factor > 1:
+                 frame = cv2.resize(frame, (frame.shape[1]//resize_factor, frame.shape[0]//resize_factor))
+
+             if rotate:
+                 frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
+
+             y1, y2, x1, x2 = crop
+             if x2 == -1: x2 = frame.shape[1]
+             if y2 == -1: y2 = frame.shape[0]
+
+             frame = frame[y1:y2, x1:x2]
+             full_frames.append(frame)
+
+     print ("Number of frames available for inference: "+str(len(full_frames)))
+     if not full_frames:
+         raise ValueError("No frames could be read from the input face file.")
+
+     temp_audio_path = os.path.join(temp_dir, 'temp_audio.wav')
+     if not audio_path.endswith('.wav'):
+         print('Extracting raw audio...')
+         command = f'ffmpeg -y -i "{audio_path}" -strict -2 "{temp_audio_path}"'
+         try:
+             subprocess.run(command, shell=True, check=True, capture_output=True)
+             audio_path = temp_audio_path
+         except subprocess.CalledProcessError as e:
+             print(f"FFmpeg error: {e.stderr.decode()}")
+             raise RuntimeError(f"Failed to extract audio from {audio_path}. Error: {e.stderr.decode()}")
+     else:
+         # Copy the wav file to temp if it's already wav to maintain consistency in naming
+         shutil.copy(audio_path, temp_audio_path)
+         audio_path = temp_audio_path
+
+
+     wav = audio.load_wav(audio_path, 16000)
+     mel = audio.melspectrogram(wav)
+     print("Mel spectrogram shape:", mel.shape)
+
+     if np.isnan(mel.reshape(-1)).sum() > 0:
+         raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
+
+     mel_chunks = []
+     mel_idx_multiplier = 80./fps
+     i = 0
+     while 1:
+         start_idx = int(i * mel_idx_multiplier)
+         if start_idx + mel_step_size > len(mel[0]):
+             mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+             break
+         mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
+         i += 1
+
+     print("Length of mel chunks: {}".format(len(mel_chunks)))
+
+     # Ensure full_frames matches mel_chunks length, or loop if static
+     if not is_static_input:
+         full_frames = full_frames[:len(mel_chunks)]
+     else:
+         # If static, replicate the first frame for the duration of the audio
+         full_frames = [full_frames[0]] * len(mel_chunks)
+
+
+     gen = datagen(full_frames.copy(), mel_chunks, box, is_static_input, wav2lip_batch_size, img_size, pads, face_det_batch_size, nosmooth)
+
+     output_avi_path = os.path.join(temp_dir, 'result.avi')
+
+     model_loaded = False
+     model = None
+     frame_h, frame_w = 0, 0
+     out = None
+
+     for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, desc="Wav2Lip Inference",
+                                                                     total=int(np.ceil(float(len(mel_chunks))/wav2lip_batch_size)))):
+         if not model_loaded:
+             model = load_model(checkpoint_path)
+             model_loaded = True
+             print ("Model loaded successfully")
+
+             frame_h, frame_w = full_frames[0].shape[:-1]
+             out = cv2.VideoWriter(output_avi_path,
+                                   cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
+             if out is None: # In case no frames were generated for some reason
+                 raise RuntimeError("Video writer could not be initialized.")
+
+
+         img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+         mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+
+         with torch.no_grad():
+             pred = model(mel_batch, img_batch)
+
+         pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+         for p, f, c in zip(pred, frames, coords):
+             y1, y2, x1, x2 = c
+             p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
+
+             f[y1:y2, x1:x2] = p
+             out.write(f)
+
+     if out:
+         out.release()
+     else:
+         print("Warning: Video writer was not initialized or no frames were processed.")
+
+
+     final_output_path = os.path.join(output_dir, output_filename)
+     command = f'ffmpeg -y -i "{audio_path}" -i "{output_avi_path}" -strict -2 -q:v 1 "{final_output_path}"'
+
+     try:
+         subprocess.run(command, shell=True, check=True, capture_output=True)
+         print(f"Output saved to: {final_output_path}")
+     except subprocess.CalledProcessError as e:
+         print(f"FFmpeg final merge error: {e.stderr.decode()}")
+         raise RuntimeError(f"Failed to merge audio and video. Error: {e.stderr.decode()}")
+
+     # Clean up temporary files (optional, but good practice)
+     # shutil.rmtree(temp_dir) # Be careful with this if you want to inspect temp files
+
+     return final_output_path
+
  # No `if __name__ == '__main__':` block here, as it's meant to be imported
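
Usage note: as the final comment says, this module is meant to be imported rather than run as a script. Below is a minimal, hypothetical sketch of the Flask glue that the comments in this file anticipate; the route name, form-field names, upload directory, and checkpoint path are illustrative assumptions, not part of this commit.

# app.py -- hypothetical caller for inference2.run_inference (sketch only;
# the route, field names, and file locations below are assumptions).
import os

from flask import Flask, request, send_file
from werkzeug.utils import secure_filename

from inference2 import run_inference

app = Flask(__name__)
UPLOAD_DIR = 'uploads'  # assumed staging area for uploaded files
os.makedirs(UPLOAD_DIR, exist_ok=True)

@app.route('/lipsync', methods=['POST'])
def lipsync():
    # run_inference expects filesystem paths, so persist both uploads first.
    face = request.files['face']
    voice = request.files['audio']
    face_path = os.path.join(UPLOAD_DIR, secure_filename(face.filename))
    audio_path = os.path.join(UPLOAD_DIR, secure_filename(voice.filename))
    face.save(face_path)
    voice.save(audio_path)

    # Writes the merged video under 'results/' and returns its path.
    result_path = run_inference(
        checkpoint_path='checkpoints/wav2lip.pth',  # assumed checkpoint location
        face_path=face_path,
        audio_path=audio_path,
        output_filename='result.mp4',
    )
    return send_file(result_path, mimetype='video/mp4')

Note that run_inference clears the shared 'temp' directory at the start of every call, so concurrent requests would need to be serialized (or given per-request temp/results paths) before a sketch like this is safe under load.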