Spaces:
dylanebert
/
Running on Zero

LGM-mini / app.py
dylanebert's picture
dylanebert HF staff
encapsulate in pipeline
26b3ad9
raw
history blame
3.45 kB
import importlib
import os
import subprocess
import sys
import uuid

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from gradio_client import Client, file
# Make sure the bundled diff-gaussian-rasterization package is importable,
# installing it from the local source tree on first run if it is missing.
try:
    import diff_gaussian_rasterization  # noqa: F401
except ImportError:
    # Use the interpreter running this app (a bare `pip` on PATH may belong to a
    # different environment) and pass an argument list so no shell is involved.
    # check=True surfaces a failed install at startup instead of ignoring it.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "./diff-gaussian-rasterization"],
        check=True,
    )
    # The failed import above can leave stale import-finder caches; refresh them
    # so later imports in this process can find the newly installed package.
    importlib.invalidate_caches()
# Directory where run() writes its generated .ply files; create it at startup
# so that write has a directory to land in.
TMP_DIR = "/tmp"
os.makedirs(name=TMP_DIR, exist_ok=True)
def _load_cuda_pipeline(repo_id, custom_pipeline):
    """Load a diffusers pipeline in float16 with remote code trusted, then move it to CUDA.

    Args:
        repo_id: model id or path passed to ``DiffusionPipeline.from_pretrained``.
        custom_pipeline: id of the custom pipeline code to load alongside it.

    Returns:
        The pipeline returned by ``.to("cuda")``.
    """
    loaded = DiffusionPipeline.from_pretrained(
        repo_id,
        custom_pipeline=custom_pipeline,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    return loaded.to("cuda")


# Turns the input image into the `images` that run() feeds to splat_pipeline
# (a multi-view diffusion model, judging by its custom pipeline name).
image_pipeline = _load_cuda_pipeline(
    "ashawkey/imagedream-ipmv-diffusers", "dylanebert/multi_view_diffusion"
)
# Turns those images into Gaussians, which run() writes out via save_ply().
splat_pipeline = _load_cuda_pipeline("dylanebert/LGM", "dylanebert/LGM")
def run(input_image, convert):
    """Generate a 3D asset from a single input image.

    The image is first passed through ``image_pipeline``. Its output goes to
    ``splat_pipeline``, whose Gaussians are saved as a ``.ply`` file under
    ``TMP_DIR``. Optionally, that file is converted to a mesh remotely.

    Args:
        input_image: array from the ``gr.Image(type="numpy")`` input, or ``None``
            when no image was provided. Presumably 8-bit pixel values, given the
            division by 255 below; confirm if other sources call this.
        convert: when truthy, convert the ``.ply`` with ``convert_to_mesh``.

    Returns:
        Path to the generated ``.ply`` file, or the result of
        ``convert_to_mesh`` when ``convert`` is set.

    Raises:
        gr.Error: if ``input_image`` is ``None``.
    """
    if input_image is None:
        # Clicking "Generate" with an empty image input passes None; show a
        # readable message instead of failing with an AttributeError on .astype.
        raise gr.Error("Please upload an input image first.")

    # Scale pixel values down by 255 into float32 for the diffusion pipeline.
    image = input_image.astype("float32") / 255.0
    images = image_pipeline(
        "", image, guidance_scale=5, num_inference_steps=30, elevation=0
    )
    gaussians = splat_pipeline(images)

    # Unique file per call, so overlapping events (e.g. example caching and a
    # user click) cannot overwrite each other's output.
    output_ply_path = os.path.join(TMP_DIR, f"output_{uuid.uuid4().hex}.ply")
    splat_pipeline.save_ply(gaussians, output_ply_path)

    if convert:
        return convert_to_mesh(output_ply_path)
    return output_ply_path
def convert_to_mesh(input_ply):
    """Convert a Gaussian-splat ``.ply`` file to a mesh via the splat-to-mesh Space.

    Args:
        input_ply: local path of the ``.ply`` file to upload.

    Returns:
        The value returned by the Space's ``/run`` endpoint through
        ``Client.predict``. Presumably a local path to the downloaded mesh;
        verify against the remote Space's API.
    """
    client = Client("https://dylanebert-splat-to-mesh.hf.space/")
    try:
        return client.predict(file(input_ply), api_name="/run")
    finally:
        # Close the client even when predict raises, so a failed conversion
        # does not leave it open.
        client.close()
_TITLE = """LGM Mini"""
_DESCRIPTION = """
<div>
A lightweight version of <a href="https://huggingface.co/spaces/ashawkey/LGM">LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation</a>.
</div>
"""
# Page stylesheet; targets the DuplicateButton created with
# elem_id="duplicate-button".
css = (
    "\n"
    "#duplicate-button {\n"
    "margin: auto;\n"
    "color: white;\n"
    "background: #1565c0;\n"
    "border-radius: 100vh;\n"
    "}\n"
)
# Build the Gradio UI: duplicate button, header, an input/output panel, and
# cached example images. Layout follows statement order and `with` nesting.
block = gr.Blocks(title=_TITLE, css=css)
with block:
    # Styled by the `css` rule for #duplicate-button.
    gr.DuplicateButton(
        value="Duplicate Space for private use", elem_id="duplicate-button"
    )
    # Header: title and description.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# " + _TITLE)
            gr.Markdown(_DESCRIPTION)
    # Main panel: inputs in the left column, 3D result in the right column.
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):

            def update_warning(checked):
                """Return red warning HTML when mesh conversion is checked, else ""."""
                if checked:
                    return '<span style="color: #ff0000;">Warning: Mesh conversion takes several minutes</span>'
                else:
                    return ""

            input_image = gr.Image(label="image", type="numpy")
            convert_checkbox = gr.Checkbox(label="Convert to Mesh")
            warning = gr.HTML()
            # Refresh the warning text whenever the checkbox value changes.
            convert_checkbox.change(
                fn=update_warning, inputs=[convert_checkbox], outputs=[warning]
            )
            button_gen = gr.Button("Generate")
        with gr.Column(scale=1):
            output_splat = gr.Model3D(label="3D Gaussians")
        # Generate: image + checkbox -> path shown in the Model3D viewer.
        button_gen.click(
            fn=run, inputs=[input_image, convert_checkbox], outputs=[output_splat]
        )
    # Example images; their outputs are precomputed (cache_examples=True) by
    # calling run() with mesh conversion disabled.
    gr.Examples(
        examples=[
            "data_test/frog_sweater.jpg",
            "data_test/bird.jpg",
            "data_test/boy.jpg",
            "data_test/cat_statue.jpg",
            "data_test/dragontoy.jpg",
            "data_test/gso_rabbit.jpg",
        ],
        inputs=[input_image],
        outputs=[output_splat],
        fn=lambda x: run(input_image=x, convert=False),
        cache_examples=True,
        label="Image-to-3D Examples",
    )
# Enable request queuing on the Blocks app, then start the server with debug
# output and a public share link requested.
app = block.queue()
app.launch(debug=True, share=True)