"""FastAPI service for audio-driven video generation with HunyuanVideoSampler.

Rank 0 serves HTTP requests; the remaining sequence-parallel ranks sit in
worker_loop() and join each generation step via torch.distributed broadcasts.
"""
import os
import threading
import traceback
import warnings
from datetime import datetime

import numpy as np
import torch
import torch.distributed as dist
import uvicorn
from fastapi import Body, FastAPI
from transformers import AutoFeatureExtractor, WhisperModel

from hymm_gradio.tool_for_end2end import *
from hymm_sp.config import parse_args
from hymm_sp.data_kits.face_align import AlignImage
from hymm_sp.modules.parallel_states import (
    initialize_distributed,
    nccl_info,
)
from hymm_sp.sample_inference_audio import HunyuanVideoSampler

warnings.filterwarnings("ignore")

MODEL_OUTPUT_PATH = os.environ.get('MODEL_BASE')

app = FastAPI()
rlock = threading.RLock()


@app.api_route('/predict2', methods=['GET', 'POST'])
def predict(data=Body(...)):
    """Single-flight HTTP entry point: concurrent requests are rejected, not queued."""
    is_acquire = False
    try:
        # Non-blocking acquire: if a request is already being served, fall
        # through and report the service as busy instead of waiting.
        is_acquire = rlock.acquire(blocking=False)
        if is_acquire:
            return predict_wrap(data)
    except Exception:
        print(traceback.format_exc())
    finally:
        if is_acquire:
            rlock.release()
    return {"errCode": -1, "info": "broken"}


def predict_wrap(input_dict=None):
    if input_dict is None:
        input_dict = {}

    if nccl_info.sp_size > 1:
        device = torch.device(f"cuda:{torch.distributed.get_rank()}")
        rank = local_rank = torch.distributed.get_rank()
        print(f"sp_size={nccl_info.sp_size}, rank {rank} local_rank {local_rank}")
    else:
        # Single-GPU fallback; without this branch `rank` is unbound below.
        device = torch.device("cuda")
        rank = local_rank = 0

    try:
        print(f"----- rank = {rank}")
        if rank == 0:
            input_dict = process_input_dict(input_dict)
            print('------- start to predict -------')

            # Parse input arguments
            image_path = input_dict["image_path"]
            driving_audio_path = input_dict["audio_path"]
            prompt = input_dict["prompt"]
            save_fps = input_dict.get("save_fps", 25)
            ret_dict = None
            if image_path is None or driving_audio_path is None:
                ret_dict = {
                    "errCode": -3,
                    "content": [{"buffer": None}],
                    "info": "input content is not valid",
                }
                print("errCode: -3, input content is not valid!")
                return ret_dict

            # Preprocess the input batch (reference image, audio features, prompt)
            torch.cuda.synchronize()
            start_time = datetime.now()
            try:
                model_kwargs_tmp = data_preprocess_server(
                    args, image_path, driving_audio_path, prompt, feature_extractor
                )
            except Exception:
                ret_dict = {
                    "errCode": -2,
                    "content": [{"buffer": None}],
                    "info": "failed to preprocess input data",
                }
                print("errCode: -2, preprocess failed!")
                return ret_dict

            text_prompt = model_kwargs_tmp["text_prompt"]
            audio_path = model_kwargs_tmp["audio_path"]
            image_path = model_kwargs_tmp["image_path"]
            fps = model_kwargs_tmp["fps"]
            audio_prompts = model_kwargs_tmp["audio_prompts"]
            audio_len = model_kwargs_tmp["audio_len"]
            motion_bucket_id_exps = model_kwargs_tmp["motion_bucket_id_exps"]
            motion_bucket_id_heads = model_kwargs_tmp["motion_bucket_id_heads"]
            pixel_value_ref = model_kwargs_tmp["pixel_value_ref"]
            pixel_value_ref_llava = model_kwargs_tmp["pixel_value_ref_llava"]

            torch.cuda.synchronize()
            preprocess_time = (datetime.now() - start_time).total_seconds()
            print("=" * 100)
            print("preprocess time :", preprocess_time)
            print("=" * 100)
        else:
            # Non-zero ranks receive every field via the broadcast below.
            text_prompt = None
            audio_path = None
            image_path = None
            fps = None
            audio_prompts = None
            audio_len = None
            motion_bucket_id_exps = None
            motion_bucket_id_heads = None
            pixel_value_ref = None
            pixel_value_ref_llava = None
    except Exception:
        traceback.print_exc()
        if rank == 0:
            return {
                "errCode": -1,  # failed during input preparation
                "content": [{"buffer": None}],
                "info": "failed to preprocess",
            }

    try:
        broadcast_params = [
            text_prompt,
            audio_path,
            image_path,
            fps,
            audio_prompts,
            audio_len,
            motion_bucket_id_exps,
            motion_bucket_id_heads,
            pixel_value_ref,
            pixel_value_ref_llava,
        ]
        # Rank 0 sends the prepared batch; worker ranks block here until it arrives.
        dist.broadcast_object_list(broadcast_params, src=0)
        outputs = generate_image_parallel(*broadcast_params)
        if rank == 0:
            samples = outputs["samples"]
            sample = samples[0].unsqueeze(0)
            # Trim the generated clip to the driving audio's length in frames.
            sample = sample[:, :, :audio_len[0]]
            video = sample[0].permute(1, 2, 3, 0).clamp(0, 1).numpy()
            video = (video * 255.).astype(np.uint8)
            output_dict = {
                "err_code": 0,
                "err_msg": "succeed",
                "video": video,
                "audio": input_dict.get("audio_path", None),
                "save_fps": save_fps,
            }
            return process_output_dict(output_dict)
    except Exception:
        traceback.print_exc()
        if rank == 0:
            return {
                "errCode": -1,  # failed to generate video
                "content": [{"buffer": None}],
                "info": "failed to generate video",
            }
    return None


def generate_image_parallel(
    text_prompt,
    audio_path,
    image_path,
    fps,
    audio_prompts,
    audio_len,
    motion_bucket_id_exps,
    motion_bucket_id_heads,
    pixel_value_ref,
    pixel_value_ref_llava,
):
    """Run one sampling step; called on every rank with the broadcast batch."""
    batch = {
        "text_prompt": text_prompt,
        "audio_path": audio_path,
        "image_path": image_path,
        "fps": fps,
        "audio_prompts": audio_prompts,
        "audio_len": audio_len,
        "motion_bucket_id_exps": motion_bucket_id_exps,
        "motion_bucket_id_heads": motion_bucket_id_heads,
        "pixel_value_ref": pixel_value_ref,
        "pixel_value_ref_llava": pixel_value_ref_llava,
    }
    samples = hunyuan_sampler.predict(args, batch, wav2vec, feature_extractor, align_instance)
    return samples


def worker_loop():
    # Non-zero ranks spin here, re-entering predict_wrap() to wait on the
    # next broadcast from rank 0.
    while True:
        predict_wrap()


if __name__ == "__main__":
    audio_args = parse_args()
    initialize_distributed(audio_args.seed)
    hunyuan_sampler = HunyuanVideoSampler.from_pretrained(audio_args.ckpt, args=audio_args)
    args = hunyuan_sampler.args

    rank = local_rank = 0
    device = torch.device("cuda")
    if nccl_info.sp_size > 1:
        device = torch.device(f"cuda:{torch.distributed.get_rank()}")
        rank = local_rank = torch.distributed.get_rank()

    # Whisper-tiny serves as the frozen audio feature extractor.
    feature_extractor = AutoFeatureExtractor.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/")
    wav2vec = WhisperModel.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/").to(
        device=device, dtype=torch.float32
    )
    wav2vec.requires_grad_(False)

    # Face detection / alignment model for the reference image.
    BASE_DIR = f'{MODEL_OUTPUT_PATH}/ckpts/det_align/'
    det_path = os.path.join(BASE_DIR, 'detface.pt')
    align_instance = AlignImage("cuda", det_path=det_path)

    if rank == 0:
        # Rank 0 serves HTTP; all other ranks wait for broadcast work.
        uvicorn.run(app, host="0.0.0.0", port=80)
    else:
        worker_loop()
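
# --- Usage sketch (illustrative only, not part of the service) ---------------
# A minimal client call, assuming the service runs on localhost:80 and that
# process_input_dict() (imported from hymm_gradio.tool_for_end2end) accepts the
# keys read by predict_wrap() above; the exact payload schema lives in that
# module, and the paths below are hypothetical.
#
#   import requests
#
#   payload = {
#       "image_path": "/data/ref.png",         # hypothetical reference image
#       "audio_path": "/data/driving.wav",     # hypothetical driving audio
#       "prompt": "A person talking in a studio",
#       "save_fps": 25,
#   }
#   resp = requests.post("http://localhost:80/predict2", json=payload)
#   print(resp.json())  # e.g. {"errCode": -1, "info": "broken"} when busy/failed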