# Spaces: Running on Zero (Hugging Face Spaces / ZeroGPU hosting banner)
import os
import shutil
from functools import partial

import gradio as gr
from common import (
    MAX_SEED,
    VERSION,
    TrellisImageTo3DPipeline,
    active_btn_by_content,
    extract_3d_representations_v2,
    extract_urdf,
    get_seed,
    image_to_3d,
    preprocess_image_fn,
    preprocess_sam_image_fn,
    select_point,
)
from gradio.themes import Default
from gradio.themes.utils.colors import slate
from gradio_litmodel3d import LitModel3D

from asset3d_gen.models.delight import DelightingModel
from asset3d_gen.models.segment import RembgRemover, SAMPredictor
from asset3d_gen.models.super_resolution import ImageRealESRGAN
from asset3d_gen.utils.gpt_clients import GPT_CLIENT
from asset3d_gen.validators.quality_checkers import (
    ImageAestheticChecker,
    ImageSegChecker,
    MeshGeoChecker,
)
from asset3d_gen.validators.urdf_convertor import URDFGenerator

# Per-process scratch root. Each Gradio session gets its own subdirectory
# keyed by the session hash (created/removed in start_session/end_session).
TMP_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)
os.makedirs(TMP_DIR, exist_ok=True)

# Module-level singletons: heavy models are instantiated once at import time
# and shared by every session/request of this app.
RBG_REMOVER = RembgRemover()  # automatic background removal (rembg)
SAM_PREDICTOR = SAMPredictor(model_type="vit_h")  # point-prompt segmentation
DELIGHT = DelightingModel()  # presumably removes baked-in lighting — confirm
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)  # 4x super-resolution
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
    "JeffreyXiang/TRELLIS-image-large"
)
# PIPELINE.cuda()  # NOTE(review): left disabled — device placement is
# presumably handled elsewhere (e.g. ZeroGPU); confirm before enabling.
IMAGE_BUFFER = {}  # in-memory image cache shared with handlers in common.py
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
URDF_CONVERTOR = URDFGenerator(GPT_CLIENT, render_view_num=4)
def start_session(req: gr.Request) -> None:
    """Create this session's scratch directory under ``TMP_DIR``.

    Wired to ``demo.load``; the directory name is the Gradio session hash.
    """
    os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True)
def end_session(req: gr.Request) -> None:
    """Delete this session's scratch directory under ``TMP_DIR``.

    Wired to ``demo.unload``. Uses EAFP instead of an ``os.path.exists``
    check so a concurrent cleanup — or a session whose directory was never
    created — cannot race between the check and the removal.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # Already gone (or never created): nothing to clean up.
        pass
# ---------------------------------------------------------------------------
# UI definition. Everything below only builds the Gradio Blocks graph and
# wires events; all heavy work (segmentation, TRELLIS generation, URDF
# export) happens inside the handlers imported from common.py.
# ---------------------------------------------------------------------------
with gr.Blocks(
    delete_cache=(43200, 43200), theme=Default(primary_hue=slate)
) as demo:
    gr.Markdown(
        f"""
        ## Image to 3D Asset Pipeline \n
        version: {VERSION} \n
        The service is temporarily deployed on `dev015-10.34.8.82: CUDA 4`.
        """
    )
    with gr.Row():
        # Left column: inputs, generation settings, and URDF extraction.
        with gr.Column(scale=2):
            # Two mutually exclusive input modes: automatic segmentation
            # (tab 0) vs. interactive SAM point segmentation (tab 1).
            with gr.Tabs() as input_tabs:
                with gr.Tab(
                    label="Image(auto seg)", id=0
                ) as single_image_input_tab:
                    image_prompt = gr.Image(
                        label="Input Image",
                        format="png",
                        image_mode="RGBA",
                        type="pil",
                        height=300,
                    )
                    gr.Markdown(
                        """
                    If you are not satisfied with the auto segmentation
                    result, please switch to the `Image(SAM seg)` tab."""
                    )
                with gr.Tab(
                    label="Image(SAM seg)", id=1
                ) as samimage_input_tab:
                    with gr.Row():
                        with gr.Column(scale=1):
                            image_prompt_sam = gr.Image(
                                label="Input Image", type="numpy", height=400
                            )
                            # Hidden holder for the SAM-segmented RGBA image;
                            # feeds image_to_3d when is_samimage is True.
                            image_seg_sam = gr.Image(
                                label="SAM Seg Image",
                                image_mode="RGBA",
                                type="pil",
                                height=400,
                                visible=False,
                            )
                        with gr.Column(scale=1):
                            image_mask_sam = gr.AnnotatedImage()

                    fg_bg_radio = gr.Radio(
                        ["foreground_point", "background_point"],
                        label="Select foreground(green) or background(red) points, by default foreground",  # noqa
                        value="foreground_point",
                    )
                    gr.Markdown(
                        """ Click the `Input Image` to select SAM points,
                        after get the satisified segmentation, click `Generate`
                        button to generate the 3D asset. \n
                        Note: If the segmented foreground is too small relative
                        to the entire image area, the generation will fail.
                    """
                    )

            # Sampler settings for the two TRELLIS stages (sparse structure,
            # then structured latents).
            with gr.Accordion(label="Generation Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        0, MAX_SEED, label="Seed", value=0, step=1
                    )
                with gr.Row():
                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed", value=False
                    )
                    project_delight = gr.Checkbox(
                        label="Backproject delighting",
                        value=True,
                    )
                gr.Markdown("Geo Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(
                        0.0,
                        10.0,
                        label="Guidance Strength",
                        value=7.5,
                        step=0.1,
                    )
                    ss_sampling_steps = gr.Slider(
                        1, 50, label="Sampling Steps", value=12, step=1
                    )
                gr.Markdown("Visual Appearance Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(
                        0.0,
                        10.0,
                        label="Guidance Strength",
                        value=3.0,
                        step=0.1,
                    )
                    slat_sampling_steps = gr.Slider(
                        1, 50, label="Sampling Steps", value=12, step=1
                    )

            # Pipeline buttons are enabled progressively: Generate ->
            # Extract 3D -> Extract URDF -> Download.
            generate_btn = gr.Button(
                "Generate(~0.5 mins)", variant="primary", interactive=False
            )
            # Hidden textbox carrying the raw mesh path between steps.
            model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
            with gr.Row():
                extract_rep3d_btn = gr.Button(
                    "Extract 3D Representation(~2 mins)",
                    variant="primary",
                    interactive=False,
                )
            # Optional user-supplied physical attributes; when empty, the
            # GPT-estimated values are used instead (see extract_urdf).
            with gr.Accordion(
                label="Enter Asset Attributes(optional)", open=False
            ):
                asset_cat_text = gr.Textbox(
                    label="Enter Asset Category (e.g., chair)"
                )
                height_range_text = gr.Textbox(
                    label="Enter Height Range in meter (e.g., 0.5-0.6)"
                )
                mass_range_text = gr.Textbox(
                    label="Enter Mass Range in kg (e.g., 1.1-1.2)"
                )
                asset_version_text = gr.Textbox(
                    label=f"Enter version (e.g., {VERSION})"
                )
            with gr.Row():
                extract_urdf_btn = gr.Button(
                    "Extract URDF(~1 mins)",
                    variant="primary",
                    interactive=False,
                )
            # Read-only display of the attributes estimated during URDF
            # extraction.
            with gr.Row():
                gr.Markdown(
                    "#### Estimated Asset 3D Attributes(No input required)"
                )
            with gr.Row():
                est_type_text = gr.Textbox(
                    label="Asset category", interactive=False
                )
                est_height_text = gr.Textbox(
                    label="Real height(.m)", interactive=False
                )
                est_mass_text = gr.Textbox(
                    label="Mass(.kg)", interactive=False
                )
                est_mu_text = gr.Textbox(
                    label="Friction coefficient", interactive=False
                )
            with gr.Row():
                download_urdf = gr.DownloadButton(
                    label="Download URDF", variant="primary", interactive=False
                )

            gr.Markdown(
                """ NOTE: If `Asset Attributes` are provided, the provided
                properties will be used; otherwise, the GPT-preset properties
                will be applied. \n
                The `Download URDF` file is restored to the real scale and
                has quality inspection, open with an editor to view details.
                """
            )

            # Example galleries — one per input tab; only one is visible at
            # a time (toggled by the tab-select handlers below). Clicking an
            # example runs the matching preprocessing function.
            with gr.Row() as single_image_example:
                examples = gr.Examples(
                    label="Image Gallery",
                    examples=[
                        [f"assets/example_image/{image}"]
                        for image in os.listdir(
                            "assets/example_image"
                        )
                    ],
                    inputs=[image_prompt],
                    fn=partial(
                        preprocess_image_fn,
                        model=RBG_REMOVER,
                        buffer=IMAGE_BUFFER,
                    ),
                    outputs=[image_prompt],
                    run_on_click=True,
                    examples_per_page=32,
                )

            with gr.Row(visible=False) as single_sam_image_example:
                examples = gr.Examples(
                    label="Image Gallery",
                    examples=[
                        f"assets/example_image/{image}"
                        for image in os.listdir(
                            "assets/example_image"
                        )
                    ],
                    inputs=[image_prompt_sam],
                    fn=partial(
                        preprocess_sam_image_fn,
                        buffer=IMAGE_BUFFER,
                        model=SAM_PREDICTOR,
                    ),
                    outputs=[image_prompt_sam],
                    run_on_click=True,
                    examples_per_page=32,
                )

        # Right column: generated previews (turntable video, Gaussian splat,
        # textured mesh).
        with gr.Column(scale=1):
            video_output = gr.Video(
                label="Generated 3D Asset",
                autoplay=True,
                loop=True,
                height=300,
            )
            model_output_gs = LitModel3D(
                label="Gaussian Representation", height=300, interactive=False
            )
            # Hidden textbox carrying the aligned Gaussian file path between
            # extract_3d_representations_v2 and extract_urdf.
            aligned_gs = gr.Textbox(visible=False)
            with gr.Row():
                model_output_mesh = LitModel3D(
                    label="Mesh Representation",
                    exposure=10.0,
                    height=300,
                    interactive=False,
                )
            gr.Markdown(
                """ The rendering of `Gaussian Representation` takes additional 10s. """  # noqa
            )

    # Cross-handler session state.
    is_samimage = gr.State(False)  # which input tab feeds image_to_3d
    output_buf = gr.State()  # opaque result of image_to_3d
    selected_points = gr.State(value=[])  # SAM click points

    # Per-session scratch dir lifecycle.
    demo.load(start_session)
    demo.unload(end_session)

    # Tab switching: flip is_samimage and swap example-gallery visibility.
    # NOTE(review): gr.Row.update is the Gradio 3.x partial-update API —
    # confirm it still exists in the installed Gradio version.
    single_image_input_tab.select(
        lambda: tuple(
            [False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
        ),
        outputs=[is_samimage, single_image_example, single_sam_image_example],
    )
    # Same values as above, but the two gallery outputs are listed in the
    # opposite order, so here visible=True applies to the SAM gallery.
    samimage_input_tab.select(
        lambda: tuple(
            [True, gr.Row.update(visible=True), gr.Row.update(visible=False)]
        ),
        outputs=[is_samimage, single_sam_image_example, single_image_example],
    )

    # Auto-seg tab: segment on upload, then reset all downstream outputs on
    # any change (upload, example click, or clear).
    image_prompt.upload(
        partial(preprocess_image_fn, model=RBG_REMOVER, buffer=IMAGE_BUFFER),
        inputs=[image_prompt],
        outputs=[image_prompt],
    )
    image_prompt.change(
        lambda: tuple(
            [
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                None,
                "",
                None,
                None,
                "",
                "",
                "",
                "",
                "",
                "",
                "",
                "",
            ]
        ),
        outputs=[
            extract_rep3d_btn,
            extract_urdf_btn,
            download_urdf,
            model_output_gs,
            aligned_gs,
            model_output_mesh,
            video_output,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
        ],
    )
    # Enable Generate only when the image actually has content.
    image_prompt.change(
        active_btn_by_content,
        inputs=image_prompt,
        outputs=generate_btn,
    )

    # SAM tab: preprocess on upload, reset downstream state on change,
    # collect click points, and enable Generate once a segmentation exists.
    image_prompt_sam.upload(
        partial(
            preprocess_sam_image_fn, buffer=IMAGE_BUFFER, model=SAM_PREDICTOR
        ),
        inputs=[image_prompt_sam],
        outputs=[image_prompt_sam],
    )
    image_prompt_sam.change(
        lambda: tuple(
            [
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                None,
                None,
                None,
                "",
                "",
                "",
                "",
                "",
                "",
                "",
                "",
                None,
                [],
            ]
        ),
        outputs=[
            extract_rep3d_btn,
            extract_urdf_btn,
            download_urdf,
            model_output_gs,
            model_output_mesh,
            video_output,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
            image_mask_sam,
            selected_points,
        ],
    )
    # gr.State(lambda: X) wraps module singletons so they can be passed as
    # handler inputs — presumably to avoid Gradio copying the heavy objects;
    # confirm against common.py's expectations.
    image_prompt_sam.select(
        select_point,
        [
            image_prompt_sam,
            selected_points,
            fg_bg_radio,
            gr.State(lambda: SAM_PREDICTOR),
        ],
        [image_mask_sam, image_seg_sam],
    )
    image_seg_sam.change(
        active_btn_by_content,
        inputs=image_seg_sam,
        outputs=generate_btn,
    )

    # Step 1: resolve the seed, run TRELLIS image->3D, then unlock step 2.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).success(
        image_to_3d,
        inputs=[
            image_prompt,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            gr.State(lambda: IMAGE_BUFFER),
            gr.State(lambda: PIPELINE),
            gr.State(lambda: TMP_DIR),
            image_seg_sam,
            is_samimage,
        ],
        outputs=[output_buf, video_output],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[extract_rep3d_btn],
    )

    # Step 2: extract mesh + Gaussian representations, then unlock step 3.
    extract_rep3d_btn.click(
        extract_3d_representations_v2,
        inputs=[
            output_buf,
            project_delight,
            gr.State(lambda: TMP_DIR),
            gr.State(lambda: DELIGHT),
            gr.State(lambda: IMAGESR_MODEL),
        ],
        outputs=[
            model_output_mesh,
            model_output_gs,
            model_output_obj,
            aligned_gs,
        ],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[extract_urdf_btn],
    )

    # Step 3: build the URDF (with GPT-based quality checks), fill in the
    # estimated attributes, then unlock the download button.
    extract_urdf_btn.click(
        extract_urdf,
        inputs=[
            aligned_gs,
            model_output_obj,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
            gr.State(lambda: TMP_DIR),
            gr.State(lambda: URDF_CONVERTOR),
            gr.State(lambda: IMAGE_BUFFER),
            gr.State(lambda: CHECKERS),
        ],
        outputs=[
            download_urdf,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
        ],
        queue=True,
        show_progress="full",
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[download_urdf],
    )
if __name__ == "__main__":
    # Launch the Gradio server (blocking) when run as a script.
    demo.launch()