import os |
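# Runtime setup for the Hugging Face Space: clear any stale ZeroGPU offload
# cache and install the pinned numpy version and the bundled PyTorch3D wheel
# before the heavy imports below.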
|
os.system('rm -rf /data-nvme/zerogpu-offload/') |
|
os.system('pip install numpy==1.23.0') |
|
os.system('pip install ./wheels/pytorch3d-0.7.3-cp310-cp310-linux_x86_64.whl') |
|
|
|
import argparse |
|
import base64 |
|
import time
from collections import defaultdict
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
|
|
import gradio as gr |
|
import spaces |
|
from flame_tracking_single_image import FlameTrackingSingleImage |
|
from ffmpeg_utils import images_to_video |
|
|
|
|
|
|
|
|
|
def parse_configs(): |
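    """Build the inference config by merging the training config, the infer
    config, CLI overrides, and the APP_* environment variables."""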
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--config', type=str) |
|
parser.add_argument('--infer', type=str) |
|
args, unknown = parser.parse_known_args() |
|
|
|
cfg = OmegaConf.create() |
|
cli_cfg = OmegaConf.from_cli(unknown) |
|
|
|
|
|
if os.environ.get('APP_INFER') is not None: |
|
args.infer = os.environ.get('APP_INFER') |
|
if os.environ.get('APP_MODEL_NAME') is not None: |
|
cli_cfg.model_name = os.environ.get('APP_MODEL_NAME') |
|
|
|
args.config = args.infer if args.config is None else args.config |
|
|
|
    cfg_train = None
    if args.config is not None:
|
cfg_train = OmegaConf.load(args.config) |
|
cfg.source_size = cfg_train.dataset.source_image_res |
|
        try:
            cfg.src_head_size = cfg_train.dataset.src_head_size
        except Exception:
            # Older configs may not define src_head_size; fall back to 112.
            cfg.src_head_size = 112
|
cfg.render_size = cfg_train.dataset.render_image.high |
|
_relative_path = os.path.join( |
|
cfg_train.experiment.parent, |
|
cfg_train.experiment.child, |
|
os.path.basename(cli_cfg.model_name).split('_')[-1], |
|
) |
|
|
|
cfg.save_tmp_dump = os.path.join('exps', 'save_tmp', _relative_path) |
|
cfg.image_dump = os.path.join('exps', 'images', _relative_path) |
|
        cfg.video_dump = os.path.join('exps', 'videos', _relative_path)
|
|
|
if args.infer is not None: |
|
cfg_infer = OmegaConf.load(args.infer) |
|
cfg.merge_with(cfg_infer) |
|
cfg.setdefault('save_tmp_dump', |
|
os.path.join('exps', cli_cfg.model_name, 'save_tmp')) |
|
cfg.setdefault('image_dump', |
|
os.path.join('exps', cli_cfg.model_name, 'images')) |
|
cfg.setdefault('video_dump', |
|
os.path.join('dumps', cli_cfg.model_name, 'videos')) |
|
cfg.setdefault('mesh_dump', |
|
os.path.join('dumps', cli_cfg.model_name, 'meshes')) |
|
|
|
cfg.motion_video_read_fps = 6 |
|
cfg.merge_with(cli_cfg) |
|
|
|
cfg.setdefault('logger', 'INFO') |
|
|
|
assert cfg.model_name is not None, 'model_name is required' |
|
|
|
return cfg, cfg_train |
|
|
|
|
|
|
|
def launch_pretrained(): |
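    """Download the pretrained tracking models from the Hugging Face Hub and unpack them."""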
|
from huggingface_hub import snapshot_download, hf_hub_download |
|
hf_hub_download(repo_id='yuandong513/flametracking_model', |
|
repo_type='model', |
|
filename='pretrain_model.tar', |
|
local_dir='./') |
|
os.system('tar -xf pretrain_model.tar && rm pretrain_model.tar') |
|
|
|
def animation_infer(renderer, gs_model_list, query_points, smplx_params, |
|
render_c2ws, render_intrs, render_bg_colors): |
|
    '''Inference code that avoids repeated forward passes.'''
|
render_h, render_w = int(render_intrs[0, 0, 1, 2] * 2), int( |
|
render_intrs[0, 0, 0, 2] * 2) |
|
|
|
render_res_list = [] |
|
num_views = render_c2ws.shape[1] |
|
start_time = time.time() |
|
|
|
|
|
    # Render each target view in turn.
|
|
|
for view_idx in range(num_views): |
|
render_res = renderer.forward_animate_gs( |
|
gs_model_list, |
|
query_points, |
|
renderer.get_single_view_smpl_data(smplx_params, view_idx), |
|
render_c2ws[:, view_idx:view_idx + 1], |
|
render_intrs[:, view_idx:view_idx + 1], |
|
render_h, |
|
render_w, |
|
render_bg_colors[:, view_idx:view_idx + 1], |
|
) |
|
render_res_list.append(render_res) |
|
    print(
        f'time elapsed (animate gs model per frame): {(time.time() - start_time) / num_views}'
    )
|
|
|
    # Collate per-view outputs: concatenate tensors along the view dimension
    # and convert image outputs to (num_views, H, W, C).
    out = defaultdict(list)
|
for res in render_res_list: |
|
for k, v in res.items(): |
|
if isinstance(v[0], torch.Tensor): |
|
out[k].append(v.detach().cpu()) |
|
else: |
|
out[k].append(v) |
|
for k, v in out.items(): |
|
|
|
if isinstance(v[0], torch.Tensor): |
|
out[k] = torch.concat(v, dim=1) |
|
if k in ['comp_rgb', 'comp_mask', 'comp_depth']: |
|
                out[k] = out[k][0].permute(0, 2, 3, 1)
|
else: |
|
out[k] = v |
|
return out |
|
|
|
|
|
def assert_input_image(input_image): |
|
if input_image is None: |
|
raise gr.Error('No image selected or uploaded!') |
|
|
|
|
|
def prepare_working_dir(): |
|
import tempfile |
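    # Return the TemporaryDirectory object itself (not .name) so the directory
    # stays alive while the Gradio state still references it.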
|
working_dir = tempfile.TemporaryDirectory() |
|
return working_dir |
|
|
|
def get_image_base64(path): |
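    """Read an image file and return it as a base64 data URI for inline HTML."""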
|
with open(path, 'rb') as image_file: |
|
encoded_string = base64.b64encode(image_file.read()).decode() |
|
return f'data:image/png;base64,{encoded_string}' |
|
|
|
|
|
def demo_lhm(flametracking): |
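    """Build and launch the Gradio demo around the given FLAME tracker."""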
|
@spaces.GPU(duration=80) |
|
    def core_fn(image: np.ndarray, video_params, working_dir):
        # End-to-end demo: save the uploaded image, run FLAME tracking
        # (preprocess -> optimize -> export), and render the preview video.
|
image_raw = os.path.join(working_dir.name, 'raw.png') |
|
with Image.fromarray(image) as img: |
|
img.save(image_raw) |
|
|
|
base_vid = os.path.basename(video_params).split('_')[0] |
|
|
|
dump_video_path = os.path.join(working_dir.name, 'output.mp4') |
|
dump_image_path = os.path.join(working_dir.name, 'output.png') |
|
|
|
|
|
omit_prefix = os.path.dirname(image_raw) |
|
image_name = os.path.basename(image_raw) |
|
uid = image_name.split('.')[0] |
|
subdir_path = os.path.dirname(image_raw).replace(omit_prefix, '') |
|
subdir_path = (subdir_path[1:] |
|
if subdir_path.startswith('/') else subdir_path) |
|
print('==> subdir_path and uid:', subdir_path, uid) |
|
|
|
dump_image_dir = os.path.dirname(dump_image_path) |
|
os.makedirs(dump_image_dir, exist_ok=True) |
|
|
|
print('==> path:', image_raw, dump_image_dir, dump_video_path) |
|
|
|
dump_tmp_dir = dump_image_dir |
|
|
|
        # Run the FLAME tracking stages and check each return code.
        return_code = flametracking.preprocess(image_raw)
        assert return_code == 0, 'flametracking preprocess failed!'
        return_code = flametracking.optimize()
        assert return_code == 0, 'flametracking optimize failed!'
        return_code, output_dir = flametracking.export()
        assert return_code == 0, 'flametracking export failed!'
|
|
|
print("==> output_dir:", output_dir) |
|
|
|
|
|
        # `image` arrives from Gradio as an HxWx3 uint8 numpy array, so it can
        # be saved directly as the reference image.
        save_ref_img_path = os.path.join(dump_tmp_dir, 'output.png')
        Image.fromarray(image).save(save_ref_img_path)
|
|
|
|
|
start_time = time.time() |
|
batch_dict = dict() |
|
|
|
        # cv2.imread returns BGR; convert to RGB before encoding the video.
        rgb = cv2.cvtColor(
            cv2.imread(os.path.join(output_dir, 'images/00000_00.png')),
            cv2.COLOR_BGR2RGB)

        # Repeat the exported frame to form a short clip and encode the video
        # once (assumes images_to_video accepts a (T, H, W, C) array of frames)
        # instead of re-encoding the same file 30 times in a loop.
        frames = np.stack([rgb] * 30, axis=0)
        images_to_video(
            frames,
            output_path=dump_video_path,
            fps=30,
            gradio_codec=False,
            verbose=True,
        )
|
|
|
return dump_image_path, dump_video_path |
|
|
|
_TITLE = '''LHM: Large Animatable Human Model''' |
|
|
|
    _DESCRIPTION = '''
    <strong>Reconstruct a human avatar in 0.2 seconds on an A100!</strong>
    '''
|
|
|
with gr.Blocks(analytics_enabled=False, delete_cache=[3600, 3600]) as demo: |
|
|
|
|
|
logo_url = './asset/logo.jpeg' |
|
logo_base64 = get_image_base64(logo_url) |
|
gr.HTML(f""" |
|
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> |
|
<div> |
|
<h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> Large Animatable Human Model </h1> |
|
</div> |
|
</div> |
|
""") |
|
|
|
gr.HTML(""" |
|
<div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;"> |
|
<a class="flex-item" href="https://arxiv.org/abs/2503.10625" target="_blank"> |
|
<img src="https://img.shields.io/badge/Paper-arXiv-darkred.svg" alt="arXiv Paper"> |
|
</a> |
|
<a class="flex-item" href="https://lingtengqiu.github.io/LHM/" target="_blank"> |
|
<img src="https://img.shields.io/badge/Project-LHM-blue" alt="Project Page"> |
|
</a> |
|
<a class="flex-item" href="https://github.com/aigc3d/LHM" target="_blank"> |
|
<img src="https://img.shields.io/github/stars/aigc3d/LHM?label=Github%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> |
|
</a> |
|
<a class="flex-item" href="https://www.youtube.com/watch?v=tivEpz_yiEo" target="_blank"> |
|
<img src="https://img.shields.io/badge/Youtube-Video-red.svg" alt="Video"> |
|
</a> |
|
</div> |
|
""") |
|
|
|
        gr.HTML(
            """<p><h4 style="color: red;">Notes: Please upload a full-body image to avoid detection errors. We simplify the pipeline on Spaces: 1) Rembg is used instead of SAM2; 2) the output video length is limited to 10s. For the best visual quality, try the inference code on GitHub instead.</h4></p>"""
        )
|
|
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(variant='panel', scale=1): |
|
with gr.Tabs(elem_id='openlrm_input_image'): |
|
with gr.TabItem('Input Image'): |
|
with gr.Row(): |
|
input_image = gr.Image(label='Input Image', |
|
image_mode='RGB', |
|
height=480, |
|
width=270, |
|
sources='upload', |
|
type='numpy', |
|
elem_id='content_image') |
|
|
|
with gr.Row(): |
|
examples = [ |
|
['asset/sample_input/00000.png'], |
|
] |
|
gr.Examples( |
|
examples=examples, |
|
inputs=[input_image], |
|
examples_per_page=10, |
|
) |
|
|
|
with gr.Column(): |
|
with gr.Tabs(elem_id='openlrm_input_video'): |
|
with gr.TabItem('Input Video'): |
|
with gr.Row(): |
|
video_input = gr.Video(label='Input Video', |
|
height=480, |
|
width=270, |
|
interactive=False) |
|
|
|
examples = [ |
|
'./asset/sample_input/demo.mp4', |
|
] |
|
|
|
gr.Examples( |
|
examples=examples, |
|
inputs=[video_input], |
|
examples_per_page=20, |
|
) |
|
with gr.Column(variant='panel', scale=1): |
|
with gr.Tabs(elem_id='openlrm_processed_image'): |
|
with gr.TabItem('Processed Image'): |
|
with gr.Row(): |
|
processed_image = gr.Image( |
|
label='Processed Image', |
|
image_mode='RGB', |
|
type='filepath', |
|
elem_id='processed_image', |
|
height=480, |
|
width=270, |
|
interactive=False) |
|
|
|
with gr.Column(variant='panel', scale=1): |
|
with gr.Tabs(elem_id='openlrm_render_video'): |
|
with gr.TabItem('Rendered Video'): |
|
with gr.Row(): |
|
output_video = gr.Video(label='Rendered Video', |
|
format='mp4', |
|
height=480, |
|
width=270, |
|
autoplay=True) |
|
|
|
|
|
with gr.Row(): |
|
with gr.Column(variant='panel', scale=1): |
|
submit = gr.Button('Generate', |
|
elem_id='openlrm_generate', |
|
variant='primary') |
|
|
|
working_dir = gr.State() |
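            # Event chain: validate the input image, create a temporary working
            # directory, then run the reconstruction pipeline.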
|
submit.click( |
|
fn=assert_input_image, |
|
inputs=[input_image], |
|
queue=False, |
|
).success( |
|
fn=prepare_working_dir, |
|
outputs=[working_dir], |
|
queue=False, |
|
).success( |
|
fn=core_fn, |
|
inputs=[input_image, video_input, |
|
working_dir], |
|
outputs=[processed_image, output_video], |
|
) |
|
|
|
demo.queue(max_size=1) |
|
demo.launch() |
|
|
|
|
|
def launch_gradio_app(): |
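    # Configure the environment variables used by parse_configs() and the LHM
    # inference code, set up the single-image FLAME tracker, and launch the demo.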
|
|
|
os.environ.update({ |
|
'APP_ENABLED': '1', |
|
'APP_MODEL_NAME': |
|
'./exps/releases/video_human_benchmark/human-lrm-500M/step_060000/', |
|
'APP_INFER': './configs/inference/human-lrm-500M.yaml', |
|
'APP_TYPE': 'infer.human_lrm', |
|
'NUMBA_THREADING_LAYER': 'omp', |
|
}) |
|
|
|
flametracking = FlameTrackingSingleImage(output_dir='tracking_output', |
|
alignment_model_path='./pretrain_model/68_keypoints_model.pkl', |
|
vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd', |
|
human_matting_path='./pretrain_model/matting/stylematte_synth.pt', |
|
facebox_model_path='./pretrain_model/FaceBoxesV2.pth', |
|
detect_iris_landmarks=True) |
|
|
|
|
|
demo_lhm(flametracking) |
|
|
|
|
|
if __name__ == '__main__': |
|
launch_pretrained() |
|
launch_gradio_app() |
|
|
|
|