"""Gradio front end for the image-to-3D asset pipeline.

Builds a ``gr.Blocks`` UI with two input modes (auto-segmented image and
SAM point-prompted image) and wires it to the generation, 3D-representation
extraction and URDF export handlers imported from ``common``.

NOTE(review): the container nesting below was reconstructed from a
whitespace-collapsed source -- confirm the layout against the running app.
"""

import os

# Must be assigned before `common` is imported; presumably `common` reads it
# at import time to select the app mode -- TODO confirm against `common`.
os.environ["GRADIO_APP"] = "imageto3d"

import gradio as gr  # noqa: E402
from common import (  # noqa: E402
    MAX_SEED,
    VERSION,
    active_btn_by_content,
    end_session,
    extract_3d_representations_v2,
    extract_urdf,
    get_seed,
    image_to_3d,
    preprocess_image_fn,
    preprocess_sam_image_fn,
    select_point,
    start_session,
)
from gradio.themes import Default  # noqa: E402
from gradio.themes.utils.colors import slate  # noqa: E402

# Example gallery images. `os.listdir` returns entries in arbitrary order, so
# list the directory once and sort it to give both galleries the same,
# deterministic ordering.
_EXAMPLE_IMAGE_DIR = "assets/example_image"
_EXAMPLE_IMAGE_PATHS = [
    f"{_EXAMPLE_IMAGE_DIR}/{image}"
    for image in sorted(os.listdir(_EXAMPLE_IMAGE_DIR))
]


# delete_cache: per the Gradio docs this is (frequency, age) in seconds,
# i.e. 43200 s = 12 h for both.
with gr.Blocks(
    delete_cache=(43200, 43200), theme=Default(primary_hue=slate)
) as demo:
    gr.Markdown(
        f"""
        ## Image to 3D Asset Pipeline \n
        version: {VERSION} \n
        The service is temporarily deployed on `dev015-10.34.8.82: CUDA 4`.
        """
    )

    with gr.Row():
        # ---- Left column: inputs, settings, action buttons, examples ----
        with gr.Column(scale=2):
            with gr.Tabs() as input_tabs:
                with gr.Tab(
                    label="Image(auto seg)", id=0
                ) as single_image_input_tab:
                    # Hidden image component; written by both preprocess
                    # handlers and passed to `image_to_3d`.
                    raw_image_cache = gr.Image(
                        format="png",
                        image_mode="RGB",
                        type="pil",
                        visible=False,
                    )
                    image_prompt = gr.Image(
                        label="Input Image",
                        format="png",
                        image_mode="RGBA",
                        type="pil",
                        height=400,
                    )
                    gr.Markdown(
                        """
                        If you are not satisfied with the auto segmentation
                        result, please switch to the `Image(SAM seg)` tab."""
                    )

                with gr.Tab(
                    label="Image(SAM seg)", id=1
                ) as samimage_input_tab:
                    with gr.Row():
                        with gr.Column(scale=1):
                            image_prompt_sam = gr.Image(
                                label="Input Image",
                                type="numpy",
                                height=400,
                            )
                            # Hidden; receives the segmentation produced by
                            # `select_point` and gates the Generate button.
                            image_seg_sam = gr.Image(
                                label="SAM Seg Image",
                                image_mode="RGBA",
                                type="pil",
                                height=400,
                                visible=False,
                            )
                        with gr.Column(scale=1):
                            image_mask_sam = gr.AnnotatedImage()

                    fg_bg_radio = gr.Radio(
                        ["foreground_point", "background_point"],
                        label="Select foreground(green) or background(red) points, by default foreground",  # noqa
                        value="foreground_point",
                    )
                    gr.Markdown(
                        """
                        Click the `Input Image` to select SAM points,
                        after get the satisified segmentation, click `Generate`
                        button to generate the 3D asset. \n
                        Note: If the segmented foreground is too small relative
                        to the entire image area, the generation will fail.
                        """
                    )

            with gr.Accordion(label="Generation Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        0, MAX_SEED, label="Seed", value=0, step=1
                    )
                with gr.Row():
                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed", value=False
                    )
                    project_delight = gr.Checkbox(
                        label="Backproject delighting",
                        value=True,
                    )
                gr.Markdown("Geo Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(
                        0.0,
                        10.0,
                        label="Guidance Strength",
                        value=7.5,
                        step=0.1,
                    )
                    ss_sampling_steps = gr.Slider(
                        1, 50, label="Sampling Steps", value=12, step=1
                    )
                gr.Markdown("Visual Appearance Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(
                        0.0,
                        10.0,
                        label="Guidance Strength",
                        value=3.0,
                        step=0.1,
                    )
                    slat_sampling_steps = gr.Slider(
                        1, 50, label="Sampling Steps", value=12, step=1
                    )

            # The three action buttons start disabled; each is enabled by the
            # event chain of the preceding stage (see the wiring below).
            generate_btn = gr.Button(
                "Generate(~0.5 mins)", variant="primary", interactive=False
            )
            # Hidden textbox carrying the third output of
            # `extract_3d_representations_v2` on to `extract_urdf`.
            model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
            with gr.Row():
                extract_rep3d_btn = gr.Button(
                    "Extract 3D Representation(~2 mins)",
                    variant="primary",
                    interactive=False,
                )
            with gr.Accordion(
                label="Enter Asset Attributes(optional)", open=False
            ):
                asset_cat_text = gr.Textbox(
                    label="Enter Asset Category (e.g., chair)"
                )
                height_range_text = gr.Textbox(
                    label="Enter Height Range in meter (e.g., 0.5-0.6)"
                )
                mass_range_text = gr.Textbox(
                    label="Enter Mass Range in kg (e.g., 1.1-1.2)"
                )
                asset_version_text = gr.Textbox(
                    label=f"Enter version (e.g., {VERSION})"
                )
            with gr.Row():
                extract_urdf_btn = gr.Button(
                    "Extract URDF with physics(~1 mins)",
                    variant="primary",
                    interactive=False,
                )
            with gr.Row():
                gr.Markdown(
                    "#### Estimated Asset 3D Attributes(No input required)"
                )
            with gr.Row():
                est_type_text = gr.Textbox(
                    label="Asset category", interactive=False
                )
                est_height_text = gr.Textbox(
                    label="Real height(.m)", interactive=False
                )
                est_mass_text = gr.Textbox(
                    label="Mass(.kg)", interactive=False
                )
                est_mu_text = gr.Textbox(
                    label="Friction coefficient", interactive=False
                )
            with gr.Row():
                download_urdf = gr.DownloadButton(
                    label="Download URDF",
                    variant="primary",
                    interactive=False,
                )

            gr.Markdown(
                """
                NOTE: If `Asset Attributes` are provided, the provided
                properties will be used; otherwise, the GPT-preset properties
                will be applied. \n
                The `Download URDF` file is restored to the real scale and
                has quality inspection, open with an editor to view details.
                """
            )

            # One example gallery per input tab; the tab `select` handlers
            # below toggle which of the two rows is visible.
            with gr.Row() as single_image_example:
                examples = gr.Examples(
                    label="Image Gallery",
                    examples=[[path] for path in _EXAMPLE_IMAGE_PATHS],
                    inputs=[image_prompt],
                    fn=preprocess_image_fn,
                    outputs=[image_prompt, raw_image_cache],
                    run_on_click=True,
                    examples_per_page=10,
                )

            with gr.Row(visible=False) as single_sam_image_example:
                examples = gr.Examples(
                    label="Image Gallery",
                    examples=list(_EXAMPLE_IMAGE_PATHS),
                    inputs=[image_prompt_sam],
                    fn=preprocess_sam_image_fn,
                    outputs=[image_prompt_sam, raw_image_cache],
                    run_on_click=True,
                    examples_per_page=10,
                )

        # ---- Right column: generated outputs ----
        with gr.Column(scale=1):
            video_output = gr.Video(
                label="Generated 3D Asset",
                autoplay=True,
                loop=True,
                height=300,
            )
            model_output_gs = gr.Model3D(
                label="Gaussian Representation", height=300, interactive=False
            )
            # Hidden textbox carrying the fourth output of
            # `extract_3d_representations_v2` on to `extract_urdf`.
            aligned_gs = gr.Textbox(visible=False)

            # NOTE(review): the CSS payload is blank in this source; the
            # `elem_id="lighter_mesh"` below suggests it was meant to style
            # the mesh viewer -- confirm whether content was lost.
            lighting_css = """ """
            gr.HTML(lighting_css)
            with gr.Row():
                model_output_mesh = gr.Model3D(
                    label="Mesh Representation",
                    height=300,
                    interactive=False,
                    clear_color=[1, 1, 1, 1],
                    elem_id="lighter_mesh",
                )

            gr.Markdown(
                """
                The rendering of `Gaussian Representation` takes additional 10s.
                """  # noqa
            )

    # ---- Per-session state ----
    is_samimage = gr.State(False)  # True while the SAM tab is selected
    output_buf = gr.State()  # first output of `image_to_3d`
    selected_points = gr.State(value=[])  # points passed to `select_point`

    demo.load(start_session)
    demo.unload(end_session)

    # ---- Tab switching ----
    # `gr.update(...)` is used instead of `gr.Row.update(...)`: the
    # per-component `.update()` classmethods were removed in Gradio 4, which
    # this file otherwise targets (`delete_cache`, `demo.unload`).
    single_image_input_tab.select(
        lambda: (False, gr.update(visible=True), gr.update(visible=False)),
        outputs=[is_samimage, single_image_example, single_sam_image_example],
    )
    samimage_input_tab.select(
        lambda: (True, gr.update(visible=True), gr.update(visible=False)),
        outputs=[is_samimage, single_sam_image_example, single_image_example],
    )

    # ---- Auto-segmentation input ----
    image_prompt.upload(
        preprocess_image_fn,
        inputs=[image_prompt],
        outputs=[image_prompt, raw_image_cache],
    )
    # A new input image invalidates every downstream result: disable the
    # later-stage buttons and clear outputs and attribute fields. The tuple
    # order must match `outputs` position by position.
    image_prompt.change(
        lambda: (
            gr.Button(interactive=False),
            gr.Button(interactive=False),
            gr.Button(interactive=False),
            None,
            "",
            None,
            None,
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            "",
        ),
        outputs=[
            extract_rep3d_btn,
            extract_urdf_btn,
            download_urdf,
            model_output_gs,
            aligned_gs,
            model_output_mesh,
            video_output,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
        ],
    )
    image_prompt.change(
        active_btn_by_content,
        inputs=image_prompt,
        outputs=generate_btn,
    )

    # ---- SAM-segmentation input ----
    image_prompt_sam.upload(
        preprocess_sam_image_fn,
        inputs=[image_prompt_sam],
        outputs=[image_prompt_sam, raw_image_cache],
    )
    # Same reset as above, plus the SAM mask preview and the clicked points.
    image_prompt_sam.change(
        lambda: (
            gr.Button(interactive=False),
            gr.Button(interactive=False),
            gr.Button(interactive=False),
            None,
            None,
            None,
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            None,
            [],
        ),
        outputs=[
            extract_rep3d_btn,
            extract_urdf_btn,
            download_urdf,
            model_output_gs,
            model_output_mesh,
            video_output,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
            image_mask_sam,
            selected_points,
        ],
    )
    image_prompt_sam.select(
        select_point,
        [
            image_prompt_sam,
            selected_points,
            fg_bg_radio,
        ],
        [image_mask_sam, image_seg_sam],
    )
    image_seg_sam.change(
        active_btn_by_content,
        inputs=image_seg_sam,
        outputs=generate_btn,
    )

    # ---- Generation pipeline ----
    # Each stage enables the next stage's button only on success.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).success(
        image_to_3d,
        inputs=[
            image_prompt,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            raw_image_cache,
            image_seg_sam,
            is_samimage,
        ],
        outputs=[output_buf, video_output],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[extract_rep3d_btn],
    )

    extract_rep3d_btn.click(
        extract_3d_representations_v2,
        inputs=[
            output_buf,
            project_delight,
        ],
        outputs=[
            model_output_mesh,
            model_output_gs,
            model_output_obj,
            aligned_gs,
        ],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[extract_urdf_btn],
    )

    extract_urdf_btn.click(
        extract_urdf,
        inputs=[
            aligned_gs,
            model_output_obj,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
        ],
        outputs=[
            download_urdf,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
        ],
        queue=True,
        show_progress="full",
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[download_urdf],
    )


if __name__ == "__main__":
    demo.launch()