# Copyright (c) 2023-2024, Qi Zuo
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

os.system('rm -rf /data-nvme/zerogpu-offload/')
os.system('pip install numpy==1.23.0')
os.system('pip install ./wheels/pytorch3d-0.7.3-cp310-cp310-linux_x86_64.whl')

import argparse
import base64
import time
from collections import defaultdict

import cv2
import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image
import gradio as gr
import spaces

from flame_tracking_single_image import FlameTrackingSingleImage
from ffmpeg_utils import images_to_video

# torch._dynamo.config.disable = True
def parse_configs():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str)
    parser.add_argument('--infer', type=str)
    args, unknown = parser.parse_known_args()

    cfg = OmegaConf.create()
    cli_cfg = OmegaConf.from_cli(unknown)

    # parse from ENV
    if os.environ.get('APP_INFER') is not None:
        args.infer = os.environ.get('APP_INFER')
    if os.environ.get('APP_MODEL_NAME') is not None:
        cli_cfg.model_name = os.environ.get('APP_MODEL_NAME')

    args.config = args.infer if args.config is None else args.config

    if args.config is not None:
        cfg_train = OmegaConf.load(args.config)
        cfg.source_size = cfg_train.dataset.source_image_res
        try:
            cfg.src_head_size = cfg_train.dataset.src_head_size
        except Exception:
            # fall back to the default head-crop resolution
            cfg.src_head_size = 112
        cfg.render_size = cfg_train.dataset.render_image.high
        _relative_path = os.path.join(
            cfg_train.experiment.parent,
            cfg_train.experiment.child,
            os.path.basename(cli_cfg.model_name).split('_')[-1],
        )
        cfg.save_tmp_dump = os.path.join('exps', 'save_tmp', _relative_path)
        cfg.image_dump = os.path.join('exps', 'images', _relative_path)
        cfg.video_dump = os.path.join('exps', 'videos',
                                      _relative_path)  # output path

    if args.infer is not None:
        cfg_infer = OmegaConf.load(args.infer)
        cfg.merge_with(cfg_infer)
        cfg.setdefault('save_tmp_dump',
                       os.path.join('exps', cli_cfg.model_name, 'save_tmp'))
        cfg.setdefault('image_dump',
                       os.path.join('exps', cli_cfg.model_name, 'images'))
        cfg.setdefault('video_dump',
                       os.path.join('dumps', cli_cfg.model_name, 'videos'))
        cfg.setdefault('mesh_dump',
                       os.path.join('dumps', cli_cfg.model_name, 'meshes'))

    cfg.motion_video_read_fps = 6
    cfg.merge_with(cli_cfg)
    cfg.setdefault('logger', 'INFO')

    assert cfg.model_name is not None, 'model_name is required'

    return cfg, cfg_train
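
# Minimal usage sketch (not executed here): with the environment variables set
# by launch_gradio_app() below, parse_configs() loads the inference YAML and
# then overlays any OmegaConf dotlist overrides given on the command line, e.g.
#
#   APP_INFER=./configs/inference/human-lrm-500M.yaml \
#   APP_MODEL_NAME=./exps/releases/video_human_benchmark/human-lrm-500M/step_060000/ \
#   python app.py render_size=512
#
# The file name app.py and the trailing render_size=512 override are
# hypothetical examples; overrides are picked up via OmegaConf.from_cli()
# and merged last, so they win over values from the YAML files.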
def launch_pretrained():
    from huggingface_hub import snapshot_download, hf_hub_download
    hf_hub_download(repo_id='yuandong513/flametracking_model',
                    repo_type='model',
                    filename='pretrain_model.tar',
                    local_dir='./')
    os.system('tar -xf pretrain_model.tar && rm pretrain_model.tar')
def animation_infer(renderer, gs_model_list, query_points, smplx_params,
                    render_c2ws, render_intrs, render_bg_colors):
    '''Inference helper that renders all target views without repeating the
    per-view forward pass.'''
    render_h, render_w = int(render_intrs[0, 0, 1, 2] * 2), int(
        render_intrs[0, 0, 0, 2] * 2)

    # render target views
    render_res_list = []
    num_views = render_c2ws.shape[1]
    start_time = time.time()

    for view_idx in range(num_views):
        render_res = renderer.forward_animate_gs(
            gs_model_list,
            query_points,
            renderer.get_single_view_smpl_data(smplx_params, view_idx),
            render_c2ws[:, view_idx:view_idx + 1],
            render_intrs[:, view_idx:view_idx + 1],
            render_h,
            render_w,
            render_bg_colors[:, view_idx:view_idx + 1],
        )
        render_res_list.append(render_res)
    print(
        f'time elapsed (animate gs model per frame): {(time.time() - start_time) / num_views}'
    )

    # collect per-view results and concatenate tensors along the view axis
    out = defaultdict(list)
    for res in render_res_list:
        for k, v in res.items():
            if isinstance(v[0], torch.Tensor):
                out[k].append(v.detach().cpu())
            else:
                out[k].append(v)
    for k, v in out.items():
        # print(f'out key:{k}')
        if isinstance(v[0], torch.Tensor):
            out[k] = torch.concat(v, dim=1)
            if k in ['comp_rgb', 'comp_mask', 'comp_depth']:
                out[k] = out[k][0].permute(
                    0, 2, 3,
                    1)  # [1, Nv, 3, H, W] -> [Nv, 3, H, W] -> [Nv, H, W, 3]
        else:
            out[k] = v
    return out
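
# Reading aid (inferred from the permute above, not an API guarantee): the
# image-like entries 'comp_rgb', 'comp_mask' and 'comp_depth' leave this
# function with shape [Nv, H, W, C], i.e. one channels-last frame per
# rendered view, ready for frame-by-frame video export.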
def assert_input_image(input_image):
    if input_image is None:
        raise gr.Error('No image selected or uploaded!')


def prepare_working_dir():
    import tempfile
    working_dir = tempfile.TemporaryDirectory()
    return working_dir


def get_image_base64(path):
    with open(path, 'rb') as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode()
    return f'data:image/png;base64,{encoded_string}'
def demo_lhm(flametracking):

    def core_fn(image: np.ndarray, video_params, working_dir):
        image_raw = os.path.join(working_dir.name, 'raw.png')
        with Image.fromarray(image) as img:
            img.save(image_raw)

        base_vid = os.path.basename(video_params).split('_')[0]
        dump_video_path = os.path.join(working_dir.name, 'output.mp4')
        dump_image_path = os.path.join(working_dir.name, 'output.png')

        # prepare dump paths
        omit_prefix = os.path.dirname(image_raw)
        image_name = os.path.basename(image_raw)
        uid = image_name.split('.')[0]
        subdir_path = os.path.dirname(image_raw).replace(omit_prefix, '')
        subdir_path = (subdir_path[1:]
                       if subdir_path.startswith('/') else subdir_path)
        print('==> subdir_path and uid:', subdir_path, uid)

        dump_image_dir = os.path.dirname(dump_image_path)
        os.makedirs(dump_image_dir, exist_ok=True)
        print('==> path:', image_raw, dump_image_dir, dump_video_path)

        dump_tmp_dir = dump_image_dir

        # run FLAME tracking: preprocess, optimize, then export the results
        return_code = flametracking.preprocess(image_raw)
        return_code = flametracking.optimize()
        return_code, output_dir = flametracking.export()
        print('==> output_dir:', output_dir)
        # save the reference image shown in the 'Processed Image' panel;
        # Gradio delivers ``image`` as an (H, W, 3) uint8 numpy array
        save_ref_img_path = os.path.join(dump_tmp_dir, 'output.png')
        vis_ref_img = image.astype(np.uint8)
        Image.fromarray(vis_ref_img).save(save_ref_img_path)
        # rendering !!!!
        start_time = time.time()
        batch_dict = dict()

        # encode the exported reference frame into the output video
        rgb = cv2.imread(os.path.join(output_dir, 'images/00000_00.png'))
        images_to_video(
            rgb,
            output_path=dump_video_path,
            fps=30,
            gradio_codec=False,
            verbose=True,
        )

        return dump_image_path, dump_video_path
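
    # core_fn's two return values feed the Gradio outputs wired up below:
    # the saved reference PNG goes to the 'Processed Image' panel and the
    # rendered MP4 goes to the 'Rendered Video' panel.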
    _TITLE = '''LHM: Large Animatable Human Model'''

    _DESCRIPTION = '''
    <strong>Reconstruct a human avatar in 0.2 seconds with an A100!</strong>
    '''

    with gr.Blocks(analytics_enabled=False, delete_cache=[3600, 3600]) as demo:

        logo_url = './asset/logo.jpeg'
        logo_base64 = get_image_base64(logo_url)
        gr.HTML(f"""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                <div>
                    <h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> Large Animatable Human Model </h1>
                </div>
            </div>
            """)
| gr.HTML(""" | |
| <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;"> | |
| <a class="flex-item" href="https://arxiv.org/abs/2503.10625" target="_blank"> | |
| <img src="https://img.shields.io/badge/Paper-arXiv-darkred.svg" alt="arXiv Paper"> | |
| </a> | |
| <a class="flex-item" href="https://lingtengqiu.github.io/LHM/" target="_blank"> | |
| <img src="https://img.shields.io/badge/Project-LHM-blue" alt="Project Page"> | |
| </a> | |
| <a class="flex-item" href="https://github.com/aigc3d/LHM" target="_blank"> | |
| <img src="https://img.shields.io/github/stars/aigc3d/LHM?label=Github%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> | |
| </a> | |
| <a class="flex-item" href="https://www.youtube.com/watch?v=tivEpz_yiEo" target="_blank"> | |
| <img src="https://img.shields.io/badge/Youtube-Video-red.svg" alt="Video"> | |
| </a> | |
| </div> | |
| """) | |
        gr.HTML(
            """<p><h4 style="color: red;">Notes: please upload a full-body image to avoid detection errors. The pipeline is simplified on Spaces: 1) Rembg is used instead of SAM2; 2) the output video length is limited to 10 s. For the best visual quality, try the inference code on GitHub instead.</h4></p>"""
        )
        # DISPLAY
        with gr.Row():
            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id='openlrm_input_image'):
                    with gr.TabItem('Input Image'):
                        with gr.Row():
                            input_image = gr.Image(label='Input Image',
                                                   image_mode='RGB',
                                                   height=480,
                                                   width=270,
                                                   sources='upload',
                                                   type='numpy',
                                                   elem_id='content_image')
                # EXAMPLES
                with gr.Row():
                    examples = [
                        ['asset/sample_input/00000.png'],
                    ]
                    gr.Examples(
                        examples=examples,
                        inputs=[input_image],
                        examples_per_page=10,
                    )

            with gr.Column():
                with gr.Tabs(elem_id='openlrm_input_video'):
                    with gr.TabItem('Input Video'):
                        with gr.Row():
                            video_input = gr.Video(label='Input Video',
                                                   height=480,
                                                   width=270,
                                                   interactive=False)

                examples = [
                    './asset/sample_input/demo.mp4',
                ]
                gr.Examples(
                    examples=examples,
                    inputs=[video_input],
                    examples_per_page=20,
                )

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id='openlrm_processed_image'):
                    with gr.TabItem('Processed Image'):
                        with gr.Row():
                            processed_image = gr.Image(
                                label='Processed Image',
                                image_mode='RGB',
                                type='filepath',
                                elem_id='processed_image',
                                height=480,
                                width=270,
                                interactive=False)

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id='openlrm_render_video'):
                    with gr.TabItem('Rendered Video'):
                        with gr.Row():
                            output_video = gr.Video(label='Rendered Video',
                                                    format='mp4',
                                                    height=480,
                                                    width=270,
                                                    autoplay=True)

        # SETTING
        with gr.Row():
            with gr.Column(variant='panel', scale=1):
                submit = gr.Button('Generate',
                                   elem_id='openlrm_generate',
                                   variant='primary')

        working_dir = gr.State()
        submit.click(
            fn=assert_input_image,
            inputs=[input_image],
            queue=False,
        ).success(
            fn=prepare_working_dir,
            outputs=[working_dir],
            queue=False,
        ).success(
            fn=core_fn,
            inputs=[input_image, video_input,
                    working_dir],  # video_params refers to the smpl dir
            outputs=[processed_image, output_video],
        )

    demo.queue(max_size=1)
    demo.launch()
def launch_gradio_app():
    os.environ.update({
        'APP_ENABLED': '1',
        'APP_MODEL_NAME':
        './exps/releases/video_human_benchmark/human-lrm-500M/step_060000/',
        'APP_INFER': './configs/inference/human-lrm-500M.yaml',
        'APP_TYPE': 'infer.human_lrm',
        'NUMBA_THREADING_LAYER': 'omp',
    })

    flametracking = FlameTrackingSingleImage(
        output_dir='tracking_output',
        alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
        vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
        human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
        facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
        detect_iris_landmarks=True)

    demo_lhm(flametracking)
if __name__ == '__main__':
    launch_pretrained()
    launch_gradio_app()