import os

import imageio
import numpy as np

os.system("bash install.sh")

from omegaconf import OmegaConf
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import rembg
import gradio as gr

from dva.io import load_from_config
from dva.ray_marcher import RayMarcher
from dva.visualize import visualize_primvolume, visualize_video_primvolume
from inference import remove_background, resize_foreground, extract_texmesh
from models.diffusion import create_diffusion
from huggingface_hub import hf_hub_download
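
# Download the pretrained single-view DiT and VAE checkpoints from the Hugging Face Hub.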
ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_sview_dit_fp16.pt")
vae_ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_vae_fp16.pt")

GRADIO_PRIM_VIDEO_PATH = 'prim.mp4'
GRADIO_RGB_VIDEO_PATH = 'rgb.mp4'
GRADIO_MAT_VIDEO_PATH = 'mat.mp4'
GRADIO_GLB_PATH = 'pbr_mesh.glb'
CONFIG_PATH = "./configs/inference_dit.yml"

config = OmegaConf.load(CONFIG_PATH)
config.checkpoint_path = ckpt_path
config.model.vae_checkpoint_path = vae_ckpt_path
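
# Build the generator (DiT), VAE, and image conditioner, then load their pretrained weights.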
model = load_from_config(config.model.generator)
state_dict = torch.load(config.checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict['ema'])

vae = load_from_config(config.model.vae)
vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu')
vae.load_state_dict(vae_state_dict['model_state_dict'])

conditioner = load_from_config(config.model.conditioner)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = vae.to(device)
conditioner = conditioner.to(device)
model = model.to(device)
model.eval()
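
# Mixed-precision settings passed to the sampler (fp16 with AMP enabled).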
amp = True
precision_dtype = torch.float16
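
# Ray marcher used to render preview videos of the generated primitive volumes.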
rm = RayMarcher(
    config.image_height,
    config.image_width,
    **config.rm,
).to(device)
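
# Per-channel latent normalization statistics, if provided by the config.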
perchannel_norm = False
if "latent_mean" in config.model:
    latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device)
    latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
    assert latent_mean.shape[-1] == config.model.generator.in_channels
    perchannel_norm = True
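
# Drop sub-model entries so that the remaining config.model describes only the PrimX primitive-volume model.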
# latent_nf is popped from the config below but is still needed at sampling time, so keep a local copy.
latent_nf = config.model.latent_nf
config.diffusion.pop("timestep_respacing")
config.model.pop("vae")
config.model.pop("vae_checkpoint_path")
config.model.pop("conditioner")
config.model.pop("generator")
config.model.pop("latent_nf")
config.model.pop("latent_mean")
config.model.pop("latent_std")
model_primx = load_from_config(config.model)
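
# Background-removal session shared across requests.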
rembg_session = rembg.new_session()


def process(input_image, input_num_steps=25, input_seed=42, input_cfg=6.0):
    torch.manual_seed(input_seed)

    os.makedirs(config.output_dir, exist_ok=True)
    output_rgb_video_path = os.path.join(config.output_dir, GRADIO_RGB_VIDEO_PATH)
    output_prim_video_path = os.path.join(config.output_dir, GRADIO_PRIM_VIDEO_PATH)
    output_mat_video_path = os.path.join(config.output_dir, GRADIO_MAT_VIDEO_PATH)
    output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)

    # Re-space the diffusion schedule to the requested number of DDIM sampling steps.
    respacing = "ddim{}".format(input_num_steps)
    diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion)
    sample_fn = diffusion.ddim_sample_loop_progressive
    fwd_fn = model.forward_with_cfg

    if input_image is None:
        raise gr.Error("Please provide an input image.")
    else:
        input_image = remove_background(input_image, rembg_session)
        input_image = resize_foreground(input_image, 0.85)
        raw_image = np.array(input_image)
        mask = (raw_image[..., -1][..., None] > 0) * 1
        raw_image = raw_image[..., :3] * mask
        input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)
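
    # Sample per-primitive parameters with DDIM and classifier-free guidance.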
    with torch.no_grad():
        latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
        batch = {}
        inf_bs = 1
        inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device)
        y = conditioner.encoder(input_cond)
        model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp)
        if input_cfg >= 0:
            model_kwargs['cfg_scale'] = input_cfg
        for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device):
            final_samples = samples
        recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
        if perchannel_norm:
            recon_param = recon_param / latent_nf * latent_std + latent_mean
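
        # Split each sampled vector into per-primitive SRT parameters and latent features, then decode the features with the VAE.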
        recon_srt_param = recon_param[:, :, 0:4]
        recon_feat_param = recon_param[:, :, 4:]
        recon_feat_param_list = []
        for inf_bidx in range(inf_bs):
            if not perchannel_norm:
                decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1 * config.model.num_prims, *latent.shape[-4:]) / latent_nf)
            else:
                decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1 * config.model.num_prims, *latent.shape[-4:]))
            recon_feat_param_list.append(decoded.detach())
        recon_feat_param = torch.concat(recon_feat_param_list, dim=0)
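
        # Undo normalization on the decoded primitives, render the preview videos, and save the denoised parameters.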
        if not perchannel_norm:
            recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05
        recon_feat_param[:, 0:1, ...] /= 5.
        recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2.
        recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1)
        recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
        visualize_video_primvolume(config.output_dir, batch, recon_param, 60, rm, device)
        prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
        torch.save({'model_state_dict': prim_params}, "{}/denoised.pt".format(config.output_dir))
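
    # Reload the denoised primitives into the PrimX model and extract a textured GLB mesh.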
    denoise_param_path = os.path.join(config.output_dir, 'denoised.pt')
    primx_ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict']
    model_primx.load_state_dict(primx_ckpt_weight)
    model_primx.to(device)
    model_primx.eval()
    with torch.no_grad():
        model_primx.srt_param[:, 1:4] *= 0.85
        extract_texmesh(config.inference, model_primx, output_glb_path, device)

    return output_rgb_video_path, output_prim_video_path, output_mat_video_path, output_glb_path


_TITLE = '''3DTopia-XL'''

_DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://frozenburning.github.io/projects/3DTopia-XL/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://github.com/3DTopia/3DTopia-XL"><img src='https://img.shields.io/github/stars/3DTopia/3DTopia-XL?style=social'/></a>
</div>

* We currently offer 1) the single-image-conditioned model; 2) the multi-view-image-conditioned model and 3) the pure-text-conditioned model will be released in the future!
* If you find the output unsatisfying, try a different random seed!
'''
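
# Gradio UI: image input and sampling controls on the left; preview videos and GLB download on the right.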
block = gr.Blocks(title=_TITLE).queue()
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
            gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        with gr.Column(scale=1):
            input_image = gr.Image(label="image", type='pil')
            input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=25)
            input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=1, value=6)
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=42)
            button_gen = gr.Button("Generate")

        with gr.Column(scale=1):
            with gr.Tab("Video"):
                output_rgb_video = gr.Video(label="RGB video")
                output_prim_video = gr.Video(label="primitive video")
                output_mat_video = gr.Video(label="material video")
            with gr.Tab("GLB"):
                output_glb = gr.File(label="glb")

    button_gen.click(process, inputs=[input_image, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb])

    gr.Examples(
        examples=[
            "assets/examples/fruit_elephant.jpg",
            "assets/examples/mei_ling_panda.png",
            "assets/examples/shuai_panda_notail.png",
        ],
        inputs=[input_image],
        outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb],
        fn=lambda x: process(input_image=x),
        cache_examples=False,
        label='Single Image to 3D PBR Asset'
    )

block.launch(server_name="0.0.0.0", share=True)