# inference.py (Updated)
import os
import subprocess
import shutil  # For clearing the temp directory

import cv2
import numpy as np
from tqdm import tqdm
import torch

import audio  # Wav2Lip's audio utilities (audio.py from the repo)
try:
    import face_detection # Assuming this is installed or in a path accessible by your Flask app
except ImportError:
    print("face_detection not found. Please ensure it's installed or available in your PYTHONPATH.")
    # You might want to raise an error or handle this gracefully if face_detection is truly optional.

# Make sure you have a models/Wav2Lip.py or similar structure
try:
    from models import Wav2Lip
except ImportError:
    print("Wav2Lip model not found. Please ensure models/Wav2Lip.py exists and is correctly configured.")
    # You might want to raise an error or handle this gracefully.


# These globals are still useful for shared configuration
mel_step_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Inference script using device: {device}')


def get_smoothened_boxes(boxes, T):
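    # Replace each box with the mean of a sliding window of up to T subsequent
    # boxes, damping frame-to-frame jitter in the detected face crops.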
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def face_detect(images, pads, face_det_batch_size, nosmooth, img_size):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)

    batch_size = face_det_batch_size

    while True:
        predictions = []
        try:
            for i in tqdm(range(0, len(images), batch_size), desc="Face Detection"):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError as e:
            if batch_size == 1:
                raise RuntimeError(f'Image too big to run face detection on GPU. Error: {e}')
            batch_size //= 2
            print('Recovering from OOM error; New face detection batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = pads
    for rect, image in zip(predictions, images):
        if rect is None:
            # Save the faulty frame for debugging
            output_dir = 'temp' # Ensure this exists or create it
            os.makedirs(output_dir, exist_ok=True)
            cv2.imwrite(os.path.join(output_dir, 'faulty_frame.jpg'), image)
            raise ValueError('Face not detected! Ensure the video/image contains a face in all the frames or try adjusting pads/box.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        
        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    if not nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector # Clean up detector
    return results

def datagen(frames, mels, box, static, wav2lip_batch_size, img_size, pads, face_det_batch_size, nosmooth):
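    # Generator: pairs each mel chunk with a video frame (cycling via modulo,
    # or frame 0 if static) and its face crop, yielding model-ready batches.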
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if box[0] == -1:
        if not static:
            face_det_results = face_detect(frames, pads, face_det_batch_size, nosmooth, img_size)
        else:
            face_det_results = face_detect([frames[0]], pads, face_det_batch_size, nosmooth, img_size)
    else:
        print('Using the specified bounding box instead of face detection...')
        y1, y2, x1, x2 = box
        face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

    for i, m in enumerate(mels):
        idx = 0 if static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (img_size, img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
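            # Zero the lower half of each reference face: Wav2Lip inpaints the
            # masked mouth region, conditioned on the corresponding mel chunk.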
            img_masked[:, img_size//2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, img_size//2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch

def _load(checkpoint_path):
    # Use torch.jit.load for TorchScript archives
    if device == 'cuda':
        model = torch.jit.load(checkpoint_path)
    else:
        # Accepts string or torch.device, not a lambda
        model = torch.jit.load(checkpoint_path, map_location='cpu')
    return model

def load_model(path):
    print("Loading scripted model from:", path)
    model = _load(path) # returns the TorchScript Module
    model = model.to(device) # move to CPU or GPU
    return model.eval() # set to eval() mode


# New function to be called from Flask app
def run_inference(
    checkpoint_path: str,
    face_path: str,
    audio_path: str,
    output_filename: str,
    static: bool = False,
    fps: float = 25.,
    pads: tuple = (0, 10, 0, 0),
    face_det_batch_size: int = 16,
    wav2lip_batch_size: int = 128,
    resize_factor: int = 1,
    crop: tuple = (0, -1, 0, -1),
    box: tuple = (-1, -1, -1, -1),
    rotate: bool = False,
    nosmooth: bool = False,
    img_size: int = 96 # Fixed for Wav2Lip
) -> str:
    """
    Runs the Wav2Lip inference process.

    Args:
        checkpoint_path (str): Path to the Wav2Lip model checkpoint.
        face_path (str): Path to the input video/image file with a face.
        audio_path (str): Path to the input audio file.
        output_filename (str): Name of the output video file (e.g., 'result.mp4').
        static (bool): If True, use only the first video frame for inference.
        fps (float): Frames per second for static image input.
        pads (tuple): Padding for face detection (top, bottom, left, right).
        face_det_batch_size (int): Batch size for face detection.
        wav2lip_batch_size (int): Batch size for the Wav2Lip model.
        resize_factor (int): Reduce the resolution by this factor.
        crop (tuple): Crop the video to a smaller region (top, bottom, left, right).
        box (tuple): Constant bounding box for the face; (-1, -1, -1, -1) enables face detection.
        rotate (bool): Rotate the video 90 degrees clockwise.
        nosmooth (bool): Disable smoothing of face detections across frames.
        img_size (int): Image size for the model.

    Returns:
        str: The path to the generated output video file.
    """
    print(f"Starting inference with: face='{face_path}', audio='{audio_path}', checkpoint='{checkpoint_path}', outfile='{output_filename}'")

    # Create necessary directories
    output_dir = 'results'
    temp_dir = 'temp'
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(temp_dir, exist_ok=True)

    # Clear temp directory for fresh run
    for item in os.listdir(temp_dir):
        item_path = os.path.join(temp_dir, item)
        if os.path.isfile(item_path):
            os.remove(item_path)
        elif os.path.isdir(item_path):
            shutil.rmtree(item_path)

    # Determine if input is static based on file extension
    is_static_input = static or (os.path.isfile(face_path) and face_path.split('.')[-1].lower() in ['jpg', 'png', 'jpeg'])

    full_frames = []
    if is_static_input:
        full_frames = [cv2.imread(face_path)]
        if full_frames[0] is None:
            raise ValueError(f"Could not read face image at: {face_path}")
    else:
        video_stream = cv2.VideoCapture(face_path)
        if not video_stream.isOpened():
            raise ValueError(f"Could not open video file at: {face_path}")
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        print('Reading video frames...')
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break
            if resize_factor > 1:
                frame = cv2.resize(frame, (frame.shape[1]//resize_factor, frame.shape[0]//resize_factor))

            if rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)

            y1, y2, x1, x2 = crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]

            frame = frame[y1:y2, x1:x2]
            full_frames.append(frame)

    print("Number of frames available for inference:", len(full_frames))
    if not full_frames:
        raise ValueError("No frames could be read from the input face file.")

    temp_audio_path = os.path.join(temp_dir, 'temp_audio.wav')

    # Updated FFmpeg command: force mono, 16-bit, 16kHz
    if not audio_path.endswith('.wav'):
        print('Extracting raw audio...')
        command = f'ffmpeg -y -i "{audio_path}" -ac 1 -ar 16000 -sample_fmt s16 "{temp_audio_path}"'
        try:
            subprocess.run(command, shell=True, check=True, capture_output=True)
            audio_path = temp_audio_path
        except subprocess.CalledProcessError as e:
            print(f"FFmpeg error: {e.stderr.decode()}")
            raise RuntimeError(f"Failed to extract audio from {audio_path}. Error: {e.stderr.decode()}")
    else:
        shutil.copy(audio_path, temp_audio_path)
        audio_path = temp_audio_path

    # Load WAV audio
    wav = audio.load_wav(audio_path, 16000)
    wav = wav.astype(np.float32)

    # Check audio length
    print(f"Extracted audio samples: {len(wav)}, duration: {len(wav)/16000:.2f} sec")
    if len(wav) < 16000:
        raise ValueError(f"Audio is too short after conversion: only {len(wav)} samples. Please upload a longer clip.")

    mel = audio.melspectrogram(wav)
    print("Mel spectrogram shape:", mel.shape)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

    mel_chunks = []
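    # With the default Wav2Lip audio settings (16 kHz sample rate, hop size 200),
    # the mel spectrogram has 80 frames per second, so 80/fps maps a video frame
    # index to the matching mel frame index.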
    mel_idx_multiplier = 80./fps
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
        i += 1

    print("Length of mel chunks: {}".format(len(mel_chunks)))

    # Ensure full_frames matches mel_chunks length, or loop if static
    if not is_static_input:
        full_frames = full_frames[:len(mel_chunks)]
    else:
        # If static, replicate the first frame for the duration of the audio
        full_frames = [full_frames[0]] * len(mel_chunks)


    gen = datagen(full_frames.copy(), mel_chunks, box, is_static_input, wav2lip_batch_size, img_size, pads, face_det_batch_size, nosmooth)

    output_avi_path = os.path.join(temp_dir, 'result.avi')

    model_loaded = False
    model = None
    frame_h, frame_w = 0, 0
    out = None

    for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, desc="Wav2Lip Inference",
                                            total=int(np.ceil(float(len(mel_chunks))/wav2lip_batch_size)))):
        if not model_loaded:
            model = load_model(checkpoint_path)
            model_loaded = True
            print ("Model loaded successfully")

            frame_h, frame_w = full_frames[0].shape[:-1]
            out = cv2.VideoWriter(output_avi_path,
                                    cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
        if out is None: # In case no frames were generated for some reason
            raise RuntimeError("Video writer could not be initialized.")


        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c
            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

            f[y1:y2, x1:x2] = p
            out.write(f)

    if out:
        out.release()
    else:
        print("Warning: Video writer was not initialized or no frames were processed.")


    final_output_path = os.path.join(output_dir, output_filename)
    command = f'ffmpeg -y -i "{audio_path}" -i "{output_avi_path}" -strict -2 -q:v 1 "{final_output_path}"'

    try:
        subprocess.run(command, shell=True, check=True, capture_output=True)
        print(f"Output saved to: {final_output_path}")
    except subprocess.CalledProcessError as e:
        print(f"FFmpeg final merge error: {e.stderr.decode()}")
        raise RuntimeError(f"Failed to merge audio and video. Error: {e.stderr.decode()}")

    # Clean up temporary files (optional, but good practice)
    # shutil.rmtree(temp_dir) # Be careful with this if you want to inspect temp files

    return final_output_path

# No `if __name__ == '__main__':` block here, as it's meant to be imported
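#
# Example usage (a minimal sketch, not part of this module): one way a Flask
# route might call run_inference. The endpoint, form field names, and paths
# below are hypothetical; the checkpoint must be a TorchScript archive, since
# load_model above uses torch.jit.load.
#
#   from flask import Flask, request, send_file
#   from inference import run_inference
#
#   app = Flask(__name__)
#
#   @app.route('/lipsync', methods=['POST'])
#   def lipsync():
#       request.files['face'].save('uploads/face_input.mp4')
#       request.files['audio'].save('uploads/audio_input.wav')
#       result_path = run_inference(
#           checkpoint_path='checkpoints/wav2lip.pth',
#           face_path='uploads/face_input.mp4',
#           audio_path='uploads/audio_input.wav',
#           output_filename='result.mp4')
#       return send_file(result_path, as_attachment=True)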