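"""Audio-driven talking-video generation with the hymm_sp HunyuanVideo sampler.

Reads a metadata file of image/audio/text items, generates one video per item
with HunyuanVideoSampler, and muxes the driving audio back into each clip with
ffmpeg.

Example invocation (a sketch: the script name and flag spellings are
assumptions based on the attribute names read from parse_args; see
hymm_sp/config.py for the exact CLI):

    MODEL_BASE=/path/to/weights python <this_script>.py \
        --ckpt <ckpt_path> --input <meta_file> --save-path ./results
"""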
import os
import numpy as np
from pathlib import Path
from loguru import logger
import torch
from einops import rearrange
import torch.distributed
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from hymm_sp.config import parse_args
from hymm_sp.sample_inference_audio import HunyuanVideoSampler
from hymm_sp.data_kits.audio_dataset import VideoAudioTextLoaderVal
from hymm_sp.data_kits.face_align import AlignImage
from transformers import WhisperModel, AutoFeatureExtractor
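
# Root directory of the downloaded model checkpoints, taken from the
# MODEL_BASE environment variable.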
MODEL_OUTPUT_PATH = os.environ.get('MODEL_BASE')


def main():
    args = parse_args()
    models_root_path = Path(args.ckpt)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` does not exist: {models_root_path}")

    # Create the save folder for the generated samples (the suffixed folder
    # must be the one created, and exist_ok makes the call safe to repeat).
    save_path = args.save_path if args.save_path_suffix == "" else f'{args.save_path}_{args.save_path_suffix}'
    os.makedirs(save_path, exist_ok=True)

    # Load models
    rank = 0
    vae_dtype = torch.float16
    device = torch.device("cuda")
    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(args.ckpt, args=args, device=device)
    # Get the updated args
    args = hunyuan_video_sampler.args
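
    # Optionally reduce GPU memory by keeping transformer blocks on the CPU
    # and onloading them to the GPU one block at a time during the forward pass.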
    if args.cpu_offload:
        from diffusers.hooks import apply_group_offloading
        onload_device = torch.device("cuda")
        apply_group_offloading(
            hunyuan_video_sampler.pipeline.transformer,
            onload_device=onload_device,
            offload_type="block_level",
            num_blocks_per_group=1,
        )
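
    # Audio encoder: a Whisper model (kept under the historical name `wav2vec`)
    # that extracts speech features from the driving audio.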
    wav2vec = WhisperModel.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/").to(device=device, dtype=torch.float32)
    wav2vec.requires_grad_(False)
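
    # Face detector/aligner used to locate the reference face in the input image.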
    BASE_DIR = f'{MODEL_OUTPUT_PATH}/ckpts/det_align/'
    det_path = os.path.join(BASE_DIR, 'detface.pt')
    align_instance = AlignImage("cuda", det_path=det_path)

    feature_extractor = AutoFeatureExtractor.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/")
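
    # The validation dataset receives the sampler's text encoders and the
    # Whisper feature extractor so it can preprocess prompts and audio.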
    kwargs = {
        "text_encoder": hunyuan_video_sampler.text_encoder,
        "text_encoder_2": hunyuan_video_sampler.text_encoder_2,
        "feature_extractor": feature_extractor,
    }
    video_dataset = VideoAudioTextLoaderVal(
        image_size=args.image_size,
        meta_file=args.input,
        **kwargs,
    )
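
    # Single-process run: a DistributedSampler with num_replicas=1 and rank=0
    # simply iterates the dataset in order.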
    sampler = DistributedSampler(video_dataset, num_replicas=1, rank=0, shuffle=False, drop_last=False)
    json_loader = DataLoader(video_dataset, batch_size=1, shuffle=False, sampler=sampler, drop_last=False)

    for batch_index, batch in enumerate(json_loader, start=1):
        fps = batch["fps"]
        videoid = batch["videoid"][0]
        audio_path = str(batch["audio_path"][0])
        output_path = f"{save_path}/{videoid}.mp4"
        output_audio_path = f"{save_path}/{videoid}_audio.mp4"
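
        # Minimal-inference mode: cap the generated clip at 129 frames.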
        if args.infer_min:
            batch["audio_len"][0] = 129
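
        # predict() consumes the aligned face and the Whisper audio features
        # and returns decoded video frames.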
        samples = hunyuan_video_sampler.predict(args, batch, wav2vec, feature_extractor, align_instance)
        sample = samples['samples'][0].unsqueeze(0)  # decoded video in [0, 1], shape (1, c, f, h, w)
        sample = sample[:, :, :batch["audio_len"][0]]  # trim to the driving audio's frame count

        video = rearrange(sample[0], "c f h w -> f h w c")
        video = (video * 255.).data.cpu().numpy().astype(np.uint8)  # (f, h, w, c) uint8 frames
        torch.cuda.empty_cache()

        final_frames = np.stack(list(video), axis=0)

        if rank == 0:
            from hymm_sp.data_kits.ffmpeg_utils import save_video
            save_video(final_frames, output_path, n_rows=len(final_frames), fps=fps.item())
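            # Mux the driving audio into the silent clip, then delete the
            # video-only file.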
            os.system(f"ffmpeg -i '{output_path}' -i '{audio_path}' -shortest '{output_audio_path}' -y -loglevel quiet; rm '{output_path}'")


if __name__ == "__main__":
    main()