# Source: Hugging Face Space by huanngzh, commit bc6d7b0 ("update"), 8.8 kB.
import os
import random
import shutil
import subprocess
import sys
from typing import List

import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

from inference_ig2mv_sdxl import (
    prepare_pipeline,
    preprocess_image,
    remove_bg,
    run_pipeline,
)
from mvadapter.utils import get_orthogonal_camera, make_image_grid, tensor_to_image
# Install runtime-only dependencies.
# Run pip through the current interpreter (``python -m pip``) so the package
# lands in the environment this app is actually running in, and pass an
# argument list instead of a shell string so no shell is involved.
# NOTE(review): --no-deps presumably avoids pip touching the already-installed
# torch stack; spandrel is presumably needed by the lazily imported `texture`
# package used in run_texturing — verify both.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "spandrel==0.4.1", "--no-deps"],
    check=True,
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): float16 is used even when DEVICE is "cpu"; confirm the CPU
# path is expected to work with half precision.
DTYPE = torch.float16
# Upper bound for randomly drawn seeds and for the seed slider.
MAX_SEED = np.iinfo(np.int32).max
# Number of views generated by the multi-view pipeline.
NUM_VIEWS = 6
# Resolution of each generated view / the preprocessed reference image.
HEIGHT = 768
WIDTH = 768
# Per-session scratch directories live under <directory of this file>/tmp.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)

# Markdown shown at the top of the UI.
HEADER = """
# 🔮 Image to Texture with [MV-Adapter](https://github.com/huanngzh/MV-Adapter)
## State-of-the-art Open Source Texture Generation Using Multi-View Diffusion Model
<p style="font-size: 1.1em;">By <a href="https://www.tripo3d.ai/" style="color: #1E90FF; text-decoration: none; font-weight: bold;">Tripo</a></p>
"""

# (reference image, mesh) pairs offered as clickable examples.
EXAMPLES = [
    ["examples/001.jpeg", "examples/001.glb"],
    ["examples/002.jpeg", "examples/002.glb"],
]
# MV-Adapter multi-view diffusion pipeline built on SDXL.
pipe = prepare_pipeline(
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    unet_model=None,
    lora_model=None,
    adapter_path="huanngzh/mv-adapter",
    scheduler=None,
    num_views=NUM_VIEWS,
    device=DEVICE,
    dtype=DTYPE,
)

# BiRefNet segmentation model, used to remove the reference image background.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(DEVICE)

# BiRefNet input transform: resize to 1024x1024, convert to a tensor and
# normalize with the ImageNet mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def remove_bg_fn(x):
    """Remove the background of ``x`` with the module-level BiRefNet model.

    ``x`` is presumably a PIL image (run_mvadapter passes one); the return
    value is whatever ``remove_bg`` produces.
    """
    return remove_bg(x, birefnet, transform_image, DEVICE)


# Checkpoints for the texturing stage (upscaler and inpainting). These paths
# are relative to the current working directory, which is how
# run_texturing refers to them as well.
if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
    hf_hub_download(
        "dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints"
    )
if not os.path.exists("checkpoints/big-lama.pt"):
    # Argument list instead of a shell string: nothing here needs a shell.
    subprocess.run(
        [
            "wget",
            "-P",
            "checkpoints/",
            "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
        ],
        check=True,
    )

# Alias of DEVICE (the original recomputed the same expression). Not referenced
# in this file; kept in case other code imports it.
device = DEVICE
def start_session(req: gr.Request):
    """Create this browser session's scratch directory under TMP_DIR.

    The directory is named after ``req.session_hash``; an existing directory
    is left untouched.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
    print("start session, mkdir", session_dir)
def end_session(req: gr.Request):
    """Delete this browser session's scratch directory under TMP_DIR.

    A directory that does not exist is not an error. For example, start_session
    may never have created it. Previously ``shutil.rmtree`` raised
    ``FileNotFoundError`` from the unload callback in that case. Other
    failures, such as permission errors, still propagate.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(session_dir)
    except FileNotFoundError:
        # Nothing to clean up.
        pass
def get_random_hex():
    """Return 16 lowercase hex characters built from 8 bytes of OS randomness.

    Used to give per-run output files unique names.
    """
    return os.urandom(8).hex()
def get_random_seed(randomize_seed, seed):
    """Pick the seed for the next run.

    Returns a fresh integer in ``[0, MAX_SEED]`` when ``randomize_seed`` is
    truthy; otherwise returns ``seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def _normalize_seed(seed, default=42):
    """Coerce a UI-supplied seed to ``int``, falling back to ``default``.

    Strings are stripped and parsed as base-10 integers. Numbers are converted
    with ``int()``; a ``gr.Slider`` may deliver a float. ``None`` or any value
    that cannot be converted yields ``default``.
    """
    if seed is None:
        return default
    try:
        if isinstance(seed, str):
            return int(seed.strip())
        return int(seed)
    except (TypeError, ValueError, OverflowError):
        return default


@spaces.GPU(duration=90)
@torch.no_grad()
def run_mvadapter(
    mesh_path,
    prompt,
    image,
    seed=42,
    guidance_scale=3.0,
    num_inference_steps=30,
    reference_conditioning_scale=1.0,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    progress=gr.Progress(track_tqdm=True),
):
    """Generate NUM_VIEWS views of ``mesh_path`` conditioned on a reference image.

    Args:
        mesh_path: Path of the input mesh (value of the ``gr.Model3D`` input).
        prompt: Text prompt forwarded to ``run_pipeline``.
        image: Reference image, either a PIL image or a path to an image file.
        seed: Generation seed. Strings are parsed as integers and numbers are
            converted with ``int()``. Missing or unparsable values fall back
            to 42.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of diffusion steps.
        reference_conditioning_scale: Strength of the image conditioning.
        negative_prompt: Negative text prompt.
        progress: Gradio progress tracker that mirrors tqdm bars.

    Returns:
        ``(images, image)``: the generated views as returned by
        ``run_pipeline``, and the reference image after background removal
        and preprocessing.

    Raises:
        gr.Error: If the mesh or the reference image is missing.
    """
    if mesh_path is None:
        raise gr.Error("Please upload an input 3D mesh.")
    if image is None:
        raise gr.Error("Please upload an input image.")

    # pre-process the reference image
    if isinstance(image, str):
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    image = remove_bg_fn(image)
    image = preprocess_image(image, HEIGHT, WIDTH)

    seed = _normalize_seed(seed)

    images, _, _, _ = run_pipeline(
        pipe,
        mesh_path=mesh_path,
        num_views=NUM_VIEWS,
        text=prompt,
        image=image,
        height=HEIGHT,
        width=WIDTH,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        # Background was already removed above.
        remove_bg_fn=None,
        reference_conditioning_scale=reference_conditioning_scale,
        negative_prompt=negative_prompt,
        device=DEVICE,
    )
    torch.cuda.empty_cache()
    return images, image
@spaces.GPU(duration=90)
@torch.no_grad()
def run_texturing(
    mesh_path: str,
    mv_images: List[Image.Image],
    uv_unwarp: bool,
    preprocess_mesh: bool,
    uv_size: int,
    req: gr.Request,
):
    """Bake the generated views into a textured GLB for this session.

    Args:
        mesh_path: Path of the input mesh.
        mv_images: Gallery value produced by run_mvadapter. The first element
            of each item is used as the view image. Items are presumably
            ``(image, caption)`` tuples as delivered by ``gr.Gallery``; verify
            against the Gradio version in use.
        uv_unwarp: Forwarded to ``TexturePipeline`` as ``uv_unwarp``.
        preprocess_mesh: Forwarded to ``TexturePipeline``.
        uv_size: Texture resolution. Coerced to ``int`` because a ``gr.Slider``
            may deliver a float.
        req: Gradio request; its ``session_hash`` selects the output directory.

    Returns:
        The shaded GLB path twice: once for the 3D viewer and once for the
        download button.

    Raises:
        gr.Error: If the mesh or the multi-view images are missing.
    """
    if mesh_path is None:
        raise gr.Error("Please upload an input 3D mesh.")
    if not mv_images:
        raise gr.Error("No multi-view images to texture; generate them first.")

    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Normally created by start_session; recreate it if it is missing.
    os.makedirs(save_dir, exist_ok=True)

    # Stitch the views into a single-row grid image that the texture pipeline
    # reads back from disk via rgb_path.
    mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
    view_images = [item[0] for item in mv_images]
    make_image_grid(view_images, rows=1).save(mv_image_path)

    # Imported lazily rather than at module level.
    # NOTE(review): presumably because `texture` needs packages installed at
    # startup (e.g. spandrel) — verify.
    from texture import ModProcessConfig, TexturePipeline

    texture_pipe = TexturePipeline(
        upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
        inpaint_ckpt_path="checkpoints/big-lama.pt",
        device=DEVICE,
    )
    textured_glb_path = texture_pipe(
        mesh_path=mesh_path,
        save_dir=save_dir,
        save_name=f"texture_mesh_{get_random_hex()}",
        uv_unwarp=uv_unwarp,
        preprocess_mesh=preprocess_mesh,
        uv_size=int(uv_size),
        rgb_path=mv_image_path,
        rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
        # One azimuth per view (six values, matching NUM_VIEWS), each offset by
        # -90 degrees.
        # TODO confirm: the repeated trailing 180s presumably belong to views
        # whose azimuth does not matter (e.g. top/bottom).
        camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
    ).shaded_model_save_path
    torch.cuda.empty_cache()
    return textured_glb_path, textured_glb_path
with gr.Blocks(title="MVAdapter") as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column():
            with gr.Row():
                input_mesh = gr.Model3D(label="Input 3D mesh")
                image_prompt = gr.Image(label="Input Image", type="pil")

            with gr.Accordion("Generation Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt (Optional)",
                    placeholder="Enter your prompt",
                    value="high quality",
                )
                # Seeds are integers: step=1 (the previous step=0 is not a
                # valid slider increment).
                seed = gr.Slider(
                    label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=25,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=3.0,
                )
                reference_conditioning_scale = gr.Slider(
                    label="Image conditioning scale",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )

            with gr.Accordion("Texture Settings", open=False):
                with gr.Row():
                    uv_unwarp = gr.Checkbox(label="Unwarp UV", value=True)
                    preprocess_mesh = gr.Checkbox(label="Preprocess Mesh", value=False)
                uv_size = gr.Slider(
                    label="UV Size", minimum=1024, maximum=8192, step=512, value=4096
                )

            gen_button = gr.Button("Generate Texture", variant="primary")

            examples = gr.Examples(
                examples=EXAMPLES,
                inputs=[image_prompt, input_mesh],
                outputs=[image_prompt],
            )

        # Right column: outputs.
        with gr.Column():
            mv_result = gr.Gallery(
                label="Multi-View Results",
                show_label=False,
                columns=[3],
                rows=[2],
                object_fit="contain",
                height="auto",
                type="pil",
            )
            textured_model_output = gr.Model3D(label="Textured GLB", interactive=False)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    # Pipeline: pick seed -> generate views -> bake texture -> enable download.
    # `.success` stops the chain when a step raises. With `.then`, texturing
    # would run on a failed generation and the download button would be
    # enabled without a valid file.
    gen_button.click(
        get_random_seed, inputs=[randomize_seed, seed], outputs=[seed]
    ).success(
        run_mvadapter,
        inputs=[
            input_mesh,
            prompt,
            image_prompt,
            seed,
            guidance_scale,
            num_inference_steps,
            reference_conditioning_scale,
        ],
        outputs=[mv_result, image_prompt],
    ).success(
        run_texturing,
        inputs=[input_mesh, mv_result, uv_unwarp, preprocess_mesh, uv_size],
        outputs=[textured_model_output, download_glb],
    ).success(
        # Return an update of the same component type as the output.
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_glb],
    )

    # Create and delete the per-session scratch directory.
    demo.load(start_session)
    demo.unload(end_session)

demo.launch()