import spaces
import gradio as gr
import torch
import tempfile
import os
from vae_wrapper import VaeWrapper, encode_video_chunk
from landmarks_extractor import LandmarksExtractor
import decord
from utils import (
    get_raw_audio,
    save_audio_video,
    calculate_splits,
    instantiate_from_config,
    create_pipeline_inputs,
)
from transformers import HubertModel
from einops import rearrange
import numpy as np
from WavLM import WavLM_wrapper
from omegaconf import OmegaConf
from inference_functions import (
    sample_keyframes,
    sample_interpolation,
)
from wordle_game import WordleGame
import torch.cuda.amp as amp  # Import amp for mixed precision
from huggingface_hub import snapshot_download

# Define the repository ID
repo_id = "toninio19/keysync"

# Download the entire repository
repo_path = snapshot_download(repo_id=repo_id)
print(f"Repository downloaded to: {repo_path}")

if torch.cuda.is_available():
    # torch.set_default_tensor_type(torch.cuda.FloatTensor)
    # Enable TF32 precision for better performance on Ampere+ GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Cache for video and audio processing
cache = {
    "video": {
        "path": None,
        "embedding": None,
        "frames": None,
        "landmarks": None,
    },
    "audio": {
        "path": None,
        "raw_audio": None,
        "hubert_embedding": None,
        "wavlm_embedding": None,
    },
}
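
# The cache above is keyed by input file path: process_video() compares the incoming
# paths against cache["video"]["path"] / cache["audio"]["path"] and reuses the stored
# embeddings, frames, and landmarks on a match, so repeated runs with the same inputs
# skip the expensive VAE, HuBERT, and WavLM passes.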

# Create mixed precision scaler (note: not used in the inference code below)
scaler = amp.GradScaler()


def load_model(
    config: str,
    device: str = "cuda",
    ckpt: str = None,
):
    """
    Load a model from a configuration file.

    Args:
        config: Path to the model configuration file
        device: Device to load the model on
        ckpt: Optional checkpoint path

    Returns:
        The instantiated model in eval mode (FP16 on CUDA, except the first-stage autoencoder)
    """
    config = OmegaConf.load(config)
    config["model"]["params"]["input_key"] = "latents"

    if ckpt is not None:
        config.model.params.ckpt_path = ckpt

    with torch.device(device):
        model = instantiate_from_config(config.model).to(device).eval()

    # Convert the model to half precision, keeping the first-stage autoencoder (VAE) in FP32
    if torch.cuda.is_available():
        model = model.half()
        model.first_stage_model = model.first_stage_model.float()
        print("Converted model to FP16 precision")

    # Compile the model for faster inference
    if torch.cuda.is_available():
        try:
            model = torch.compile(model)
            print("Successfully compiled model with torch.compile()")
        except Exception as e:
            print(f"Warning: Failed to compile model: {e}")

    return model


# Default media paths
DEFAULT_VIDEO_PATH = os.path.join(
    os.path.dirname(__file__), "assets", "sample_video.mp4"
)
DEFAULT_AUDIO_PATH = os.path.join(
    os.path.dirname(__file__), "assets", "sample_audio.wav"
)


# @spaces.GPU(duration=60)
# def load_all_models():
#     global \
#         keyframe_model, \
#         interpolation_model, \
#         vae_model, \
#         hubert_model, \
#         wavlm_model, \
#         landmarks_extractor
#     vae_model = VaeWrapper("video")
#     vae_model = vae_model.half()  # Convert to half precision
#     try:
#         vae_model = torch.compile(vae_model)
#         print("Successfully compiled vae_model in FP16")
#     except Exception as e:
#         print(f"Warning: Failed to compile vae_model: {e}")
#
#     hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
#     hubert_model = hubert_model.half()  # Convert to half precision
#     try:
#         hubert_model = torch.compile(hubert_model)
#         print("Successfully compiled hubert_model in FP16")
#     except Exception as e:
#         print(f"Warning: Failed to compile hubert_model: {e}")
#
#     wavlm_model = WavLM_wrapper(
#         model_size="Base+",
#         feed_as_frames=False,
#         merge_type="None",
#         model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
#     ).cuda()
#     wavlm_model = wavlm_model.half()  # Convert to half precision
#     try:
#         wavlm_model = torch.compile(wavlm_model)
#         print("Successfully compiled wavlm_model in FP16")
#     except Exception as e:
#         print(f"Warning: Failed to compile wavlm_model: {e}")
#
#     landmarks_extractor = LandmarksExtractor()
#
#     keyframe_model = load_model(
#         config="keyframe.yaml",
#         ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
#     )
#     interpolation_model = load_model(
#         config="interpolation.yaml",
#         ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
#     )
#     keyframe_model.en_and_decode_n_samples_a_time = 2
#     interpolation_model.en_and_decode_n_samples_a_time = 2
#     return (
#         keyframe_model,
#         interpolation_model,
#         vae_model,
#         hubert_model,
#         wavlm_model,
#         landmarks_extractor,
#     )
#
# (
#     keyframe_model,
#     interpolation_model,
#     vae_model,
#     hubert_model,
#     wavlm_model,
#     landmarks_extractor,
# ) = load_all_models()

keyframe_model = None
interpolation_model = None
vae_model = None
hubert_model = None
wavlm_model = None
landmarks_extractor = None
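
# The model globals above are populated lazily on the first call to process_video()
# below; until then they stay None so that importing this module stays lightweight.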


def compute_video_embedding(video_reader, min_len, vae_model):
    """Compute VAE latent embeddings (and resized frames) from a video."""
    total_frames = min_len

    encoded = []
    video_frames = []
    chunk_size = 16
    resolution = 512

    # Create a progress bar for Gradio
    progress = gr.Progress()

    # Calculate total chunks for progress tracking
    total_chunks = (total_frames + chunk_size - 1) // chunk_size

    for i, start_idx in enumerate(range(0, total_frames, chunk_size)):
        # Update progress bar
        progress(i / total_chunks, desc="Processing video chunks")

        end_idx = min(start_idx + chunk_size, total_frames)
        video_chunk = video_reader.get_batch(range(start_idx, end_idx))

        # Resize the video chunk to the target resolution
        video_chunk = rearrange(video_chunk, "f h w c -> f c h w")
        video_chunk = torch.nn.functional.interpolate(
            video_chunk,
            size=(resolution, resolution),
            mode="bilinear",
            align_corners=False,
        )
        video_chunk = rearrange(video_chunk, "f c h w -> f h w c")

        video_frames.append(video_chunk)

        # Convert chunk to FP16 if using CUDA
        if torch.cuda.is_available():
            video_chunk = video_chunk.half()

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            encoded.append(encode_video_chunk(vae_model, video_chunk, resolution))

    encoded = torch.cat(encoded, dim=0)
    video_frames = torch.cat(video_frames, dim=0)
    video_frames = rearrange(video_frames, "f h w c -> f c h w")
    torch.cuda.empty_cache()
    return encoded, video_frames
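
# Note: frames are encoded in chunks of 16 at 512x512 to bound peak GPU memory during
# VAE encoding; if the encode step runs out of memory, lowering `chunk_size` inside
# compute_video_embedding is a reasonable first adjustment.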


def compute_hubert_embedding(raw_audio, hubert_model):
    """Compute HuBERT embeddings from raw audio."""
    print(f"Computing audio embedding from {raw_audio.shape}")

    # Normalize the audio before feeding it to HuBERT
    audio = (
        (raw_audio - raw_audio.mean()) / torch.sqrt(raw_audio.var() + 1e-7)
    ).unsqueeze(0)
    chunks = 16000 * 20

    # Create a progress bar for Gradio
    progress = gr.Progress()

    # Get audio embeddings
    audio_embeddings = []
    splits = list(calculate_splits(audio, chunks))
    total_splits = len(splits)

    for i, chunk in enumerate(splits):
        # Update progress bar
        progress(i / total_splits, desc="Processing audio chunks")

        # Convert audio chunk to half precision on CUDA
        if torch.cuda.is_available():
            chunk_cuda = chunk.cuda().half()
        else:
            chunk_cuda = chunk

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            hidden_states = hubert_model(chunk_cuda)[0]

        audio_embeddings.append(hidden_states)

    audio_embeddings = torch.cat(audio_embeddings, dim=1)

    # Pad to an even number of feature frames so they can be grouped in pairs
    if audio_embeddings.shape[1] % 2 != 0:
        audio_embeddings = torch.cat(
            [audio_embeddings, torch.zeros_like(audio_embeddings[:, :1])], dim=1
        )
    audio_embeddings = rearrange(audio_embeddings, "() (f d) c -> f d c", d=2)
    torch.cuda.empty_cache()
    return audio_embeddings
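
# Note: HuBERT emits roughly 50 feature frames per second of 16 kHz audio, so the
# `d=2` grouping above pairs consecutive features to line up with 25 fps video
# (one pair of audio features per video frame).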


def compute_wavlm_embedding(raw_audio, wavlm_model):
    """Compute WavLM embeddings from raw audio."""
    # Split the waveform into per-frame segments of 640 samples
    audio = rearrange(raw_audio, "(f s) -> f s", s=640)

    # Pad to an even number of frames
    if audio.shape[0] % 2 != 0:
        audio = torch.cat([audio, torch.zeros(1, 640)], dim=0)
    chunks = 500

    # Create a progress bar for Gradio
    progress = gr.Progress()

    # Get audio embeddings
    audio_embeddings = []
    splits = list(calculate_splits(audio, chunks))
    total_splits = len(splits)

    for i, chunk in enumerate(splits):
        # Update progress bar
        progress(i / total_splits, desc="Processing audio chunks")

        # Convert chunk to half precision on CUDA
        if torch.cuda.is_available():
            chunk_cuda = chunk.unsqueeze(0).cuda().half()
        else:
            chunk_cuda = chunk.unsqueeze(0)

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            wavlm_hidden_states = wavlm_model(chunk_cuda).squeeze(0)
        audio_embeddings.append(wavlm_hidden_states)

    audio_embeddings = torch.cat(audio_embeddings, dim=0)
    torch.cuda.empty_cache()
    return audio_embeddings
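
# Note: s=640 corresponds to 16000 Hz / 25 fps = 640 audio samples per video frame,
# so each row fed to WavLM lines up with exactly one frame of video.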


def extract_video_landmarks(video_frames, landmarks_extractor):
    """Extract facial landmarks from video frames."""
    # Create a progress bar for Gradio
    progress = gr.Progress()

    landmarks = []
    batch_size = 10

    for i in range(0, len(video_frames), batch_size):
        # Update progress bar
        progress(i / len(video_frames), desc="Extracting facial landmarks")

        batch = video_frames[i : i + batch_size].cpu().float()
        batch_landmarks = landmarks_extractor.extract_landmarks(batch)
        landmarks.extend(batch_landmarks)

    torch.cuda.empty_cache()

    # Convert landmarks to a list of numpy arrays with consistent shape
    processed_landmarks = []
    expected_shape = (68, 2)  # Common shape for facial landmarks

    # Process each landmark to ensure consistent shape
    last_valid_landmark = None
    for i, lm in enumerate(landmarks):
        if lm is not None and isinstance(lm, np.ndarray) and lm.shape == expected_shape:
            processed_landmarks.append(lm)
            last_valid_landmark = lm
        else:
            # Print information about inconsistent landmarks
            if lm is None:
                print(f"Warning: Landmark at index {i} is None")
            elif not isinstance(lm, np.ndarray):
                print(
                    f"Warning: Landmark at index {i} is not a numpy array, type: {type(lm)}"
                )
            elif lm.shape != expected_shape:
                print(
                    f"Warning: Landmark at index {i} has shape {lm.shape}, expected {expected_shape}"
                )

            # Replace an invalid landmark with the most recent valid one if available
            if last_valid_landmark is not None:
                processed_landmarks.append(last_valid_landmark.copy())
            else:
                # If no valid landmark has been seen yet, look ahead for a valid one
                found_future_valid = False
                for future_lm in landmarks[i + 1 :]:
                    if (
                        future_lm is not None
                        and isinstance(future_lm, np.ndarray)
                        and future_lm.shape == expected_shape
                    ):
                        processed_landmarks.append(future_lm.copy())
                        found_future_valid = True
                        break

                # If no valid landmark is found ahead, fall back to zeros
                if not found_future_valid:
                    processed_landmarks.append(np.zeros(expected_shape))

    return np.array(processed_landmarks)


def sample(
    audio_list,
    gt_keyframes,
    masks_keyframes,
    to_remove,
    test_keyframes_list,
    num_frames,
    device,
    emb,
    force_uc_zero_embeddings,
    n_batch_keyframes,
    n_batch,
    test_interpolation_list,
    audio_interpolation_list,
    masks_interpolation,
    gt_interpolation,
    model_keyframes,
    model,
):
    # Create a progress bar for Gradio
    progress = gr.Progress()

    condition = torch.zeros(1, 3, 512, 512).to(device)
    if torch.cuda.is_available():
        condition = condition.half()

    audio_list = rearrange(audio_list, "(b t) c d -> b t c d", t=num_frames)
    gt_keyframes = rearrange(gt_keyframes, "(b t) c h w -> b t c h w", t=num_frames)
    masks_keyframes = rearrange(
        masks_keyframes, "(b t) c h w -> b t c h w", t=num_frames
    )

    # Convert to_remove into chunks of num_frames
    to_remove_chunks = [
        to_remove[i : i + num_frames] for i in range(0, len(to_remove), num_frames)
    ]
    test_keyframes_list = [
        test_keyframes_list[i : i + num_frames]
        for i in range(0, len(test_keyframes_list), num_frames)
    ]

    audio_cond = audio_list
    if emb is not None:
        embeddings = emb.unsqueeze(0).to(device)
        if torch.cuda.is_available():
            embeddings = embeddings.half()
    else:
        embeddings = None

    # One batch of keyframes covers roughly 7 seconds of video
    chunk_size = 2
    complete_video = []
    start_idx = 0
    last_frame_z = None
    last_frame_x = None
    last_keyframe_idx = None
    last_to_remove = None

    total_chunks = (len(audio_cond) + chunk_size - 1) // chunk_size

    for chunk_idx, chunk_start in enumerate(range(0, len(audio_cond), chunk_size)):
        # Update progress bar
        progress(chunk_idx / total_chunks, desc="Generating video")

        # Clear GPU cache between chunks
        torch.cuda.empty_cache()

        chunk_end = min(chunk_start + chunk_size, len(audio_cond))

        chunk_audio_cond = audio_cond[chunk_start:chunk_end].cuda()
        if torch.cuda.is_available():
            chunk_audio_cond = chunk_audio_cond.half()

        chunk_gt_keyframes = gt_keyframes[chunk_start:chunk_end].cuda()
        chunk_masks = masks_keyframes[chunk_start:chunk_end].cuda()
        if torch.cuda.is_available():
            chunk_gt_keyframes = chunk_gt_keyframes.half()
            chunk_masks = chunk_masks.half()

        test_keyframes_list_unwrapped = [
            elem
            for sublist in test_keyframes_list[chunk_start:chunk_end]
            for elem in sublist
        ]
        to_remove_chunks_unwrapped = [
            elem
            for sublist in to_remove_chunks[chunk_start:chunk_end]
            for elem in sublist
        ]

        if last_keyframe_idx is not None:
            test_keyframes_list_unwrapped = [
                last_keyframe_idx
            ] + test_keyframes_list_unwrapped
            to_remove_chunks_unwrapped = [last_to_remove] + to_remove_chunks_unwrapped

        last_keyframe_idx = test_keyframes_list_unwrapped[-1]
        last_to_remove = to_remove_chunks_unwrapped[-1]

        # Find the first non-None keyframe in the chunk
        first_keyframe = next(
            (kf for kf in test_keyframes_list_unwrapped if kf is not None), None
        )

        # Find the last non-None keyframe in the chunk
        last_keyframe = next(
            (kf for kf in reversed(test_keyframes_list_unwrapped) if kf is not None),
            None,
        )

        start_idx = next(
            (
                idx
                for idx, comb in enumerate(test_interpolation_list)
                if comb[0] == first_keyframe
            ),
            None,
        )
        end_idx = next(
            (
                idx
                for idx, comb in enumerate(reversed(test_interpolation_list))
                if comb[1] == last_keyframe
            ),
            None,
        )

        if start_idx is not None and end_idx is not None:
            end_idx = (
                len(test_interpolation_list) - 1 - end_idx
            )  # Adjust for reversed enumeration
            end_idx += 1

        if start_idx is None:
            break
        # Fall back to the end of the list if no matching end keyframe was found
        if end_idx is None or end_idx < start_idx:
            end_idx = len(audio_interpolation_list)

        audio_interpolation_list_chunk = audio_interpolation_list[start_idx:end_idx]
        chunk_masks_interpolation = masks_interpolation[start_idx:end_idx]
        gt_interpolation_chunks = gt_interpolation[start_idx:end_idx]

        if torch.cuda.is_available():
            audio_interpolation_list_chunk = [
                chunk.half() for chunk in audio_interpolation_list_chunk
            ]
            chunk_masks_interpolation = [
                chunk.half() for chunk in chunk_masks_interpolation
            ]
            gt_interpolation_chunks = [
                chunk.half() for chunk in gt_interpolation_chunks
            ]

        progress(chunk_idx / total_chunks, desc="Generating keyframes")

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            samples_z = sample_keyframes(
                model_keyframes,
                chunk_audio_cond,
                chunk_gt_keyframes,
                chunk_masks,
                condition.cuda(),
                num_frames,
                24,
                0.0,
                device,
                embeddings.cuda() if embeddings is not None else None,
                force_uc_zero_embeddings,
                n_batch_keyframes,
                0,
                1.0,
                None,
                gt_as_cond=False,
            )

        if last_frame_x is not None:
            # samples_x = torch.cat([last_frame_x.unsqueeze(0), samples_x], axis=0)
            samples_z = torch.cat([last_frame_z.unsqueeze(0), samples_z], axis=0)

        # last_frame_x = samples_x[-1]
        last_frame_z = samples_z[-1]

        progress(chunk_idx / total_chunks, desc="Interpolating frames")

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            vid = sample_interpolation(
                model,
                samples_z,
                # samples_x,
                audio_interpolation_list_chunk,
                gt_interpolation_chunks,
                chunk_masks_interpolation,
                condition.cuda(),
                num_frames,
                device,
                1,
                24,
                0.0,
                force_uc_zero_embeddings,
                n_batch,
                chunk_size,
                1.0,
                None,
                cut_audio=False,
                to_remove=to_remove_chunks_unwrapped,
            )

        if chunk_start == 0:
            complete_video = vid
        else:
            complete_video = np.concatenate([complete_video[:-1], vid], axis=0)

    return complete_video
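
# Flow summary for sample(): for each chunk of audio conditioning, keyframes are
# generated first (sample_keyframes), then the interpolation model fills in the frames
# between them (sample_interpolation). Consecutive chunks overlap at the boundary
# keyframe (the previous chunk's last latent is prepended), and the duplicated seam
# frame is dropped when concatenating (complete_video[:-1]).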


def process_video(video_input, audio_input):
    """Main processing function: generate a lip-synchronized video from the inputs."""
    # Display a message to the user about the processing time
    gr.Info("Processing video. This may take a while...", duration=10)
    gr.Info(
        "If you're tired of waiting, try playing the Wordle game in the other tab!",
        duration=10,
    )

    max_num_seconds = 6

    global \
        vae_model, \
        hubert_model, \
        wavlm_model, \
        landmarks_extractor, \
        keyframe_model, \
        interpolation_model

    # Lazily create the models on the first request
    if vae_model is None:
        vae_model = VaeWrapper("video")
        vae_model = vae_model.half()  # Convert to half precision
        try:
            vae_model = torch.compile(vae_model)
            print("Successfully compiled vae_model in FP16")
        except Exception as e:
            print(f"Warning: Failed to compile vae_model: {e}")

    if hubert_model is None:
        hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
        hubert_model = hubert_model.half()  # Convert to half precision
        try:
            hubert_model = torch.compile(hubert_model)
            print("Successfully compiled hubert_model in FP16")
        except Exception as e:
            print(f"Warning: Failed to compile hubert_model: {e}")

    if wavlm_model is None:
        wavlm_model = WavLM_wrapper(
            model_size="Base+",
            feed_as_frames=False,
            merge_type="None",
            model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
        ).cuda()
        wavlm_model = wavlm_model.half()  # Convert to half precision
        try:
            wavlm_model = torch.compile(wavlm_model)
            print("Successfully compiled wavlm_model in FP16")
        except Exception as e:
            print(f"Warning: Failed to compile wavlm_model: {e}")

    if landmarks_extractor is None:
        landmarks_extractor = LandmarksExtractor()

    if keyframe_model is None:
        keyframe_model = load_model(
            config="keyframe.yaml",
            ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
        )
    if interpolation_model is None:
        interpolation_model = load_model(
            config="interpolation.yaml",
            ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
        )

    keyframe_model.en_and_decode_n_samples_a_time = 2
    interpolation_model.en_and_decode_n_samples_a_time = 2

    # Use default media if none provided
    if video_input is None:
        video_input = DEFAULT_VIDEO_PATH
        print(f"Using default video: {DEFAULT_VIDEO_PATH}")
    if audio_input is None:
        audio_input = DEFAULT_AUDIO_PATH
        print(f"Using default audio: {DEFAULT_AUDIO_PATH}")

    # try:
    # Use the file paths as cache keys
    video_path_hash = video_input
    audio_path_hash = audio_input

    # Check if we need to recompute video embeddings
    video_cache_hit = cache["video"]["path"] == video_path_hash
    audio_cache_hit = cache["audio"]["path"] == audio_path_hash

    if video_cache_hit and audio_cache_hit:
        print("Using cached video and audio computations")
        # Make copies of cached data to avoid modifying the cache
        video_embedding = cache["video"]["embedding"].clone()
        video_frames = cache["video"]["frames"].clone()
        video_landmarks = cache["video"]["landmarks"].copy()
        raw_audio = cache["audio"]["raw_audio"].clone()
        raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
        hubert_embedding = cache["audio"]["hubert_embedding"].clone()
        wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()

        # Ensure all data is truncated to the same length if needed
        min_len = min(
            len(video_frames),
            len(raw_audio),
            len(hubert_embedding),
            len(wavlm_embedding),
        )
        video_frames = video_frames[:min_len]
        video_embedding = video_embedding[:min_len]
        video_landmarks = video_landmarks[:min_len]
        raw_audio = raw_audio[:min_len]
        hubert_embedding = hubert_embedding[:min_len]
        wavlm_embedding = wavlm_embedding[:min_len]
        raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
    else:
        # Process video if needed
        if not video_cache_hit:
            print("Computing video embeddings and landmarks")
            video_reader = decord.VideoReader(video_input)
            decord.bridge.set_bridge("torch")

            if not audio_cache_hit:
                # Need to process audio to determine min_len
                raw_audio = get_raw_audio(audio_input, 16000)
                if len(raw_audio) == 0 or len(video_reader) == 0:
                    raise ValueError("Empty audio or video input")

                min_len = min(len(raw_audio), len(video_reader))

                # Store the full audio in the cache
                cache["audio"]["path"] = audio_path_hash
                cache["audio"]["raw_audio"] = raw_audio.clone()

                # Create a truncated copy for processing
                raw_audio = raw_audio[:min_len]
                raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
            else:
                # Use cached audio - make a copy
                if cache["audio"]["raw_audio"] is None:
                    raise ValueError("Cached audio is None")
                raw_audio = cache["audio"]["raw_audio"].clone()

                if len(raw_audio) == 0 or len(video_reader) == 0:
                    raise ValueError("Empty cached audio or video input")

                min_len = min(len(raw_audio), len(video_reader))

                # Create a truncated copy for processing
                raw_audio = raw_audio[:min_len]
                raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")

            # Compute video embeddings and landmarks - store the full versions in the cache
            video_embedding, video_frames = compute_video_embedding(
                video_reader, len(video_reader), vae_model
            )
            video_landmarks = extract_video_landmarks(video_frames, landmarks_extractor)

            # Update the video cache with the full versions
            cache["video"]["path"] = video_path_hash
            cache["video"]["embedding"] = video_embedding
            cache["video"]["frames"] = video_frames
            cache["video"]["landmarks"] = video_landmarks

            # Create truncated copies for processing
            video_embedding = video_embedding[:min_len]
            video_frames = video_frames[:min_len]
            video_landmarks = video_landmarks[:min_len]
        else:
            # Use cached video data - make copies
            print("Using cached video computations")
            if (
                cache["video"]["embedding"] is None
                or cache["video"]["frames"] is None
                or cache["video"]["landmarks"] is None
            ):
                raise ValueError("One or more video cache entries are None")

            if not audio_cache_hit:
                # New audio with cached video
                raw_audio = get_raw_audio(audio_input, 16000)
                if len(raw_audio) == 0:
                    raise ValueError("Empty audio input")

                # Store the full audio in the cache
                cache["audio"]["path"] = audio_path_hash
                cache["audio"]["raw_audio"] = raw_audio.clone()

                # Make copies of the video data
                video_embedding = cache["video"]["embedding"].clone()
                video_frames = cache["video"]["frames"].clone()
                video_landmarks = cache["video"]["landmarks"].copy()

                # Determine the truncation length and create truncated copies
                min_len = min(len(raw_audio), len(video_frames))
                raw_audio = raw_audio[:min_len]
                raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
                video_frames = video_frames[:min_len]
                video_embedding = video_embedding[:min_len]
                video_landmarks = video_landmarks[:min_len]
            else:
                # Both video and audio are cached - already handled by the first branch above
                pass

    # Process audio if needed
    if not audio_cache_hit:
        print("Computing audio embeddings")

        # Compute audio embeddings with the truncated audio
        hubert_embedding = compute_hubert_embedding(raw_audio_reshape, hubert_model)
        wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape, wavlm_model)

        # Update the audio cache with the full embeddings
        # Note: raw_audio was already cached above
        cache["audio"]["hubert_embedding"] = hubert_embedding.clone()
        cache["audio"]["wavlm_embedding"] = wavlm_embedding.clone()
    else:
        # Use cached audio data - make copies
        if (
            cache["audio"]["hubert_embedding"] is None
            or cache["audio"]["wavlm_embedding"] is None
        ):
            raise ValueError("One or more audio embedding cache entries are None")

        hubert_embedding = cache["audio"]["hubert_embedding"].clone()
        wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()

        # Make sure embeddings match the truncated video length if needed
        if "min_len" in locals() and (
            min_len < len(hubert_embedding) or min_len < len(wavlm_embedding)
        ):
            hubert_embedding = hubert_embedding[:min_len]
            wavlm_embedding = wavlm_embedding[:min_len]

    # Apply the max_num_seconds limit
    if max_num_seconds > 0:
        # Convert seconds to frames (assuming 25 fps)
        max_frames = int(max_num_seconds * 25)

        # Truncate all data to max_frames
        video_embedding = video_embedding[:max_frames]
        video_frames = video_frames[:max_frames]
        video_landmarks = video_landmarks[:max_frames]
        hubert_embedding = hubert_embedding[:max_frames]
        wavlm_embedding = wavlm_embedding[:max_frames]
        raw_audio = raw_audio[:max_frames]
        raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")

    # Validate shapes before proceeding
    assert video_embedding.shape[0] == hubert_embedding.shape[0], (
        f"Video embedding length ({video_embedding.shape[0]}) doesn't match Hubert embedding length ({hubert_embedding.shape[0]})"
    )
    assert video_embedding.shape[0] == wavlm_embedding.shape[0], (
        f"Video embedding length ({video_embedding.shape[0]}) doesn't match WavLM embedding length ({wavlm_embedding.shape[0]})"
    )
    assert video_embedding.shape[0] == video_landmarks.shape[0], (
        f"Video embedding length ({video_embedding.shape[0]}) doesn't match landmarks length ({video_landmarks.shape[0]})"
    )

    print(f"Hubert embedding shape: {hubert_embedding.shape}")
    print(f"WavLM embedding shape: {wavlm_embedding.shape}")
    print(f"Video embedding shape: {video_embedding.shape}")
    print(f"Video landmarks shape: {video_landmarks.shape}")

    # Create pipeline inputs for the models
    (
        interpolation_chunks,
        keyframe_chunks,
        audio_interpolation_chunks,
        audio_keyframe_chunks,
        emb_cond,
        masks_keyframe_chunks,
        masks_interpolation_chunks,
        to_remove,
        audio_interpolation_idx,
        audio_keyframe_idx,
    ) = create_pipeline_inputs(
        hubert_embedding,
        wavlm_embedding,
        14,
        video_embedding,
        video_landmarks,
        overlap=1,
        add_zero_flag=True,
        mask_arms=None,
        nose_index=28,
    )

    complete_video = sample(
        audio_keyframe_chunks,
        keyframe_chunks,
        masks_keyframe_chunks,
        to_remove,
        audio_keyframe_idx,
        14,
        "cuda",
        emb_cond,
        [],
        3,
        3,
        audio_interpolation_idx,
        audio_interpolation_chunks,
        masks_interpolation_chunks,
        interpolation_chunks,
        keyframe_model,
        interpolation_model,
    )

    complete_audio = rearrange(raw_audio[: complete_video.shape[0]], "f s -> () (f s)")

    # Convert frames to video and combine with the new audio
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
        output_path = temp_video.name

    print("Saving video to", output_path)
    save_audio_video(complete_video, audio=complete_audio, save_path=output_path)
    torch.cuda.empty_cache()
    return output_path

    # except Exception as e:
    #     raise e
    #     print(f"Error processing video: {str(e)}")
    #     return None


# def get_max_duration(video_input, audio_input):
#     """Get the maximum duration in seconds for the slider"""
#     try:
#         # Default to 60 seconds if files don't exist
#         if video_input is None or not os.path.exists(video_input):
#             video_input = DEFAULT_VIDEO_PATH
#         if audio_input is None or not os.path.exists(audio_input):
#             audio_input = DEFAULT_AUDIO_PATH
#
#         # Get video duration
#         video_reader = decord.VideoReader(video_input)
#         video_duration = len(video_reader) / video_reader.get_avg_fps()
#
#         # Get audio duration
#         raw_audio = get_raw_audio(audio_input, 16000)
#         audio_duration = len(raw_audio) / 25  # Assuming 25 fps
#
#         # Return the minimum of the two durations
#         return min(video_duration, audio_duration)
#     except Exception as e:
#         print(f"Error getting max duration: {str(e)}")
#         return 60  # Default to 60 seconds


# Create Gradio interface
with gr.Blocks(
    title="KeySync: A Robust Approach for Leakage-free Lip Synchronization in High Resolution"
) as demo:
    gr.Markdown(
        "# KeySync: A Robust Approach for Leakage-free Lip Synchronization in High Resolution"
    )
    gr.Markdown(
        "Upload a video and an audio clip to create a new video with the same visuals, lip-synchronized to the new audio."
    )

    with gr.Tabs():
        with gr.TabItem("Video Synchronization"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(
                        label="Input Video",
                        value=DEFAULT_VIDEO_PATH
                        if os.path.exists(DEFAULT_VIDEO_PATH)
                        else None,
                        width=512,
                        height=512,
                    )
                    audio_input = gr.Audio(
                        label="Input Audio",
                        type="filepath",
                        value=DEFAULT_AUDIO_PATH
                        if os.path.exists(DEFAULT_AUDIO_PATH)
                        else None,
                    )
                    # max_duration = gr.State(value=60)  # Default max duration
                    # max_seconds_slider = gr.Slider(
                    #     minimum=0,
                    #     maximum=60,  # Will be updated dynamically
                    #     value=0,
                    #     step=1,
                    #     label="Max Duration (seconds, 0 = full length)",
                    #     info="Limit the processing duration (0 means use full length)",
                    # )
                    process_button = gr.Button("Generate Synchronized Video")

                with gr.Column("Output Video"):
                    video_output = gr.Video(label="Output Video", width=512, height=512)

            # # Update slider max value when inputs change
            # def update_slider_max(video, audio):
            #     max_dur = get_max_duration(video, audio)
            #     return {"maximum": max_dur, "__type__": "update"}

            # video_input.change(
            #     update_slider_max, [video_input, audio_input], [max_seconds_slider]
            # )
            # audio_input.change(
            #     update_slider_max, [video_input, audio_input], [max_seconds_slider]
            # )

            # Kick off processing; progress updates and the Wordle hint are shown via
            # gr.Progress / gr.Info inside process_video
            process_button.click(
                fn=process_video,
                inputs=[video_input, audio_input],
                outputs=video_output,
            )
| gr.Markdown("## How it works") | |
| gr.Markdown(""" | |
| 1. The system extracts embeddings and landmarks from the input video | |
| 2. Audio embeddings are computed from the input audio | |
| 3. A keyframe model generates key visual frames | |
| 4. An interpolation model creates a smooth video between keyframes | |
| 5. The final video is rendered with the new audio | |
| """) | |
| gr.Markdown(""" | |
| ## Limitations | |
| Due to GPU restrictions on Hugging Face Spaces, the demo is limited to processing videos of maximum 6 seconds in length. For longer videos or better performance, we recommend using the inference scripts provided in this repository (https://github.com/antonibigata/keysync) to run KeySync locally on your own hardware. | |
| """) | |
| if __name__ == "__main__": | |
| import spaces | |
| demo.launch() | |