# inference.py (Updated)
from os import listdir, path

import numpy as np
import scipy, cv2, os, sys, argparse, audio
import json, subprocess, random, string
from tqdm import tqdm
from glob import glob
import torch

try:
    import face_detection  # must be installed or on a path visible to your Flask app
except ImportError:
    print("face_detection not found. Please ensure it's installed or available in your PYTHONPATH.")
    # Raise here instead if face detection is not optional for your deployment.

# Expects a models/Wav2Lip.py (or equivalent package) defining the model class
try:
    from models import Wav2Lip
except ImportError:
    print("Wav2Lip model not found. Please ensure models/Wav2Lip.py exists and is correctly configured.")
    # Raise here instead if you want a hard failure.

import platform
import shutil  # for clearing the temp directory

# Globals shared across the helpers below
mel_step_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Inference script using {} for inference.'.format(device))
def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes
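
# Each box is replaced by the mean of a T-frame window starting at that frame
# (final windows are clamped to the last T frames), which damps frame-to-frame
# jitter in the detected face boxes.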

def face_detect(images, pads, face_det_batch_size, nosmooth, img_size):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)

    batch_size = face_det_batch_size
    while True:
        predictions = []
        try:
            for i in tqdm(range(0, len(images), batch_size), desc="Face Detection"):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError as e:
            if batch_size == 1:
                raise RuntimeError(f'Image too big to run face detection on GPU. Error: {e}')
            # Halve the batch size and retry on out-of-memory errors
            batch_size //= 2
            print('Recovering from OOM error; new face detection batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = pads
    for rect, image in zip(predictions, images):
        if rect is None:
            # Save the faulty frame for debugging
            output_dir = 'temp'
            os.makedirs(output_dir, exist_ok=True)
            cv2.imwrite(os.path.join(output_dir, 'faulty_frame.jpg'), image)
            raise ValueError('Face not detected! Ensure the video/image contains a face in all the frames, or try adjusting pads/box.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    if not nosmooth:
        boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector  # free detector memory
    return results

def datagen(frames, mels, box, static, wav2lip_batch_size, img_size, pads, face_det_batch_size, nosmooth):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if box[0] == -1:
        if not static:
            face_det_results = face_detect(frames, pads, face_det_batch_size, nosmooth, img_size)
        else:
            face_det_results = face_detect([frames[0]], pads, face_det_batch_size, nosmooth, img_size)
    else:
        print('Using the specified bounding box instead of face detection...')
        y1, y2, x1, x2 = box
        face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

    for i, m in enumerate(mels):
        idx = 0 if static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (img_size, img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch
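
# Note: each yielded img_batch is NHWC with 6 channels per sample: the face
# with its lower half zeroed (the region the model redraws with new lip
# movements) concatenated with the unmasked reference copy, scaled to [0, 1].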

def _load(checkpoint_path):
    # Use torch.jit.load for TorchScript archives
    if device == 'cuda':
        model = torch.jit.load(checkpoint_path)
    else:
        # map_location accepts a string or torch.device, not a lambda
        model = torch.jit.load(checkpoint_path, map_location='cpu')
    return model

def load_model(path):
    print("Loading scripted model from:", path)
    model = _load(path)       # returns the TorchScript Module
    model = model.to(device)  # move to CPU or GPU
    return model.eval()       # set to eval mode
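
# Note (assumption): _load() expects a TorchScript archive, while the original
# Wav2Lip release ships plain .pth state_dict checkpoints. A minimal conversion
# sketch, assuming the models.Wav2Lip class from the original repo and its
# usual {'state_dict': ...} layout with DataParallel 'module.' prefixes
# (paths are illustrative):
#
#   model = Wav2Lip()
#   ckpt = torch.load('checkpoints/wav2lip_gan.pth', map_location='cpu')
#   model.load_state_dict({k.replace('module.', ''): v
#                          for k, v in ckpt['state_dict'].items()})
#   scripted = torch.jit.trace(model.eval(),
#                              (torch.randn(1, 1, 80, 16),    # mel chunk
#                               torch.randn(1, 6, 96, 96)))   # masked + reference faces
#   scripted.save('checkpoints/wav2lip_scripted.pt')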

# New function to be called from the Flask app
def run_inference(
    checkpoint_path: str,
    face_path: str,
    audio_path: str,
    output_filename: str,
    static: bool = False,
    fps: float = 25.,
    pads: list = [0, 10, 0, 0],
    face_det_batch_size: int = 16,
    wav2lip_batch_size: int = 128,
    resize_factor: int = 1,
    crop: list = [0, -1, 0, -1],
    box: list = [-1, -1, -1, -1],
    rotate: bool = False,
    nosmooth: bool = False,
    img_size: int = 96  # fixed at 96 for Wav2Lip
) -> str:
    """
    Runs the Wav2Lip inference process.

    Args:
        checkpoint_path (str): Path to the Wav2Lip model checkpoint.
        face_path (str): Path to the input video/image file with a face.
        audio_path (str): Path to the input audio file.
        output_filename (str): Name of the output video file (e.g., 'result.mp4').
        static (bool): If True, use only the first video frame for inference.
        fps (float): Frames per second, used when the input is a static image.
        pads (list): Padding for face detection (top, bottom, left, right).
        face_det_batch_size (int): Batch size for face detection.
        wav2lip_batch_size (int): Batch size for the Wav2Lip model.
        resize_factor (int): Reduce the input resolution by this factor.
        crop (list): Crop the video to a smaller region (top, bottom, left, right).
        box (list): Constant bounding box for the face; skips face detection.
        rotate (bool): Rotate the video right by 90 degrees.
        nosmooth (bool): Disable temporal smoothing of face detections.
        img_size (int): Face crop size fed to the model.

    Returns:
        str: The path to the generated output video file.
    """
print(f"Starting inference with: face='{face_path}', audio='{audio_path}', checkpoint='{checkpoint_path}', outfile='{output_filename}'") | |
# Create necessary directories | |
output_dir = 'results' | |
temp_dir = 'temp' | |
os.makedirs(output_dir, exist_ok=True) | |
os.makedirs(temp_dir, exist_ok=True) | |
# Clear temp directory for fresh run | |
for item in os.listdir(temp_dir): | |
item_path = os.path.join(temp_dir, item) | |
if os.path.isfile(item_path): | |
os.remove(item_path) | |
elif os.path.isdir(item_path): | |
shutil.rmtree(item_path) | |
    # Determine whether the input is static based on the file extension
    is_static_input = static or (os.path.isfile(face_path) and
                                 face_path.split('.')[-1].lower() in ['jpg', 'png', 'jpeg'])

    full_frames = []
    if is_static_input:
        full_frames = [cv2.imread(face_path)]
        if full_frames[0] is None:
            raise ValueError(f"Could not read face image at: {face_path}")
    else:
        video_stream = cv2.VideoCapture(face_path)
        if not video_stream.isOpened():
            raise ValueError(f"Could not open video file at: {face_path}")
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        print('Reading video frames...')
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break
            if resize_factor > 1:
                frame = cv2.resize(frame, (frame.shape[1] // resize_factor, frame.shape[0] // resize_factor))
            if rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)

            y1, y2, x1, x2 = crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]
            frame = frame[y1:y2, x1:x2]

            full_frames.append(frame)

    print("Number of frames available for inference: " + str(len(full_frames)))
    if not full_frames:
        raise ValueError("No frames could be read from the input face file.")
    temp_audio_path = os.path.join(temp_dir, 'temp_audio.wav')
    if not audio_path.endswith('.wav'):
        print('Extracting raw audio...')
        command = f'ffmpeg -y -i "{audio_path}" -strict -2 "{temp_audio_path}"'
        try:
            subprocess.run(command, shell=True, check=True, capture_output=True)
            audio_path = temp_audio_path
        except subprocess.CalledProcessError as e:
            print(f"FFmpeg error: {e.stderr.decode()}")
            raise RuntimeError(f"Failed to extract audio from {audio_path}. Error: {e.stderr.decode()}")
    else:
        # Copy the wav into temp so downstream paths stay consistent
        shutil.copy(audio_path, temp_audio_path)
        audio_path = temp_audio_path

    wav = audio.load_wav(audio_path, 16000)
    # Explicitly cast to float32 for resampy/numba compatibility
    wav = wav.astype(np.float32)

    mel = audio.melspectrogram(wav)
    print("Mel spectrogram shape:", mel.shape)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('Mel contains NaN! If you are using a TTS voice, add a small epsilon of noise to the wav file and try again.')
    mel_chunks = []
    mel_idx_multiplier = 80. / fps
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        i += 1

    print("Length of mel chunks: {}".format(len(mel_chunks)))
    # Ensure full_frames matches mel_chunks in length, or loop if static
    if not is_static_input:
        full_frames = full_frames[:len(mel_chunks)]
    else:
        # If static, replicate the first frame for the duration of the audio
        full_frames = [full_frames[0]] * len(mel_chunks)
    gen = datagen(full_frames.copy(), mel_chunks, box, is_static_input, wav2lip_batch_size,
                  img_size, pads, face_det_batch_size, nosmooth)

    output_avi_path = os.path.join(temp_dir, 'result.avi')
    model_loaded = False
    model = None
    frame_h, frame_w = 0, 0
    out = None

    for i, (img_batch, mel_batch, frames, coords) in enumerate(
            tqdm(gen, desc="Wav2Lip Inference",
                 total=int(np.ceil(float(len(mel_chunks)) / wav2lip_batch_size)))):
        if not model_loaded:
            # Lazily load the model and the video writer on the first batch
            model = load_model(checkpoint_path)
            model_loaded = True
            print("Model loaded successfully")

            frame_h, frame_w = full_frames[0].shape[:-1]
            out = cv2.VideoWriter(output_avi_path,
                                  cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

        if out is None:  # in case no frames were generated for some reason
            raise RuntimeError("Video writer could not be initialized.")

        # NHWC -> NCHW float tensors on the target device
        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

        # Paste each predicted face crop back into its source frame
        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c
            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
            f[y1:y2, x1:x2] = p
            out.write(f)

    if out is not None:
        out.release()
    else:
        print("Warning: video writer was not initialized; no frames were processed.")
    final_output_path = os.path.join(output_dir, output_filename)
    # Mux the (possibly extracted) audio with the silent avi into the final file
    command = f'ffmpeg -y -i "{audio_path}" -i "{output_avi_path}" -strict -2 -q:v 1 "{final_output_path}"'
    try:
        subprocess.run(command, shell=True, check=True, capture_output=True)
        print(f"Output saved to: {final_output_path}")
    except subprocess.CalledProcessError as e:
        print(f"FFmpeg final merge error: {e.stderr.decode()}")
        raise RuntimeError(f"Failed to merge audio and video. Error: {e.stderr.decode()}")

    # Temporary files are left in place for inspection; uncomment to clean up:
    # shutil.rmtree(temp_dir)

    return final_output_path

# No `if __name__ == '__main__':` block here, as this module is meant to be imported.
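
# Example (sketch, not part of this module): calling run_inference from a
# Flask route. The app object, route, upload field names, and checkpoint path
# below are illustrative assumptions, not something this file defines.
#
#   from flask import Flask, request, send_file
#   import inference
#
#   app = Flask(__name__)
#
#   @app.route('/lipsync', methods=['POST'])
#   def lipsync():
#       request.files['face'].save('uploads/face.mp4')
#       request.files['audio'].save('uploads/audio.wav')
#       outfile = inference.run_inference(
#           checkpoint_path='checkpoints/wav2lip_scripted.pt',
#           face_path='uploads/face.mp4',
#           audio_path='uploads/audio.wav',
#           output_filename='result.mp4')
#       return send_file(outfile, mimetype='video/mp4')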