"""Gradio front-end that forwards image-to-3D requests to the TRELLIS Space.

Each helper below is a thin wrapper around ``client.predict`` for one named
endpoint of the remote Hugging Face Space. The local UI only collects
parameters and displays whatever files the remote endpoints return.
"""

from gradio_client import Client, handle_file
import gradio as gr

# Initialize Gradio Client
# NOTE(review): this connects to the remote Space at import time, and this one
# ``client`` object is shared by every visitor of this demo. The extract
# endpoints below receive no reference to a generated model, so they presumably
# read state the Space kept from the previous /image_to_3d call on this
# client's session; concurrent users may therefore overwrite each other's
# results. Confirm against the Space's behavior.
client = Client("JeffreyXiang/TRELLIS")  # Replace with your Hugging Face Space


# Helper Functions for API Calls
def start_session():
    """Call the Space's ``/start_session`` endpoint and return its raw result.

    NOTE(review): not wired into the UI below; confirm whether the Space
    needs this call before ``/image_to_3d``.
    """
    return client.predict(api_name="/start_session")


def preprocess_image(image):
    """Send one local image file to ``/preprocess_image``; return the raw result.

    Args:
        image: Path of a local image file, uploaded via ``handle_file``.
    """
    return client.predict(
        image=handle_file(image),
        api_name="/preprocess_image",
    )


def preprocess_images(images):
    """Send several local image files to ``/preprocess_images``.

    Args:
        images: Iterable of local image file paths. Each is wrapped as a
            ``{"image": ..., "caption": None}`` entry before sending.

    Returns:
        The endpoint's raw result.
    """
    processed_images = [
        {"image": handle_file(img), "caption": None} for img in images
    ]
    return client.predict(
        images=processed_images,
        api_name="/preprocess_images",
    )


def get_seed(randomize_seed, seed):
    """Forward ``randomize_seed`` and ``seed`` to ``/get_seed``; return its result."""
    return client.predict(
        randomize_seed=randomize_seed,
        seed=seed,
        api_name="/get_seed",
    )


def image_to_3d(image, seed, ss_guidance_strength, ss_sampling_steps,
                slat_guidance_strength, slat_sampling_steps, multiimage_algo):
    """Generate a 3D asset from a single image via ``/image_to_3d``.

    Args:
        image: Local file path of the uploaded image (the UI's ``gr.Image``
            uses ``type="filepath"``), or ``None`` when nothing was uploaded.
        seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength,
        slat_sampling_steps, multiimage_algo: Forwarded unchanged to the
            endpoint parameters of the same names.

    Returns:
        The ``"video"`` entry of the endpoint's result, which the UI shows in
        its preview ``gr.Video``.

    Raises:
        gr.Error: If no image was uploaded. This shows a readable message in
            the UI instead of passing ``None`` to ``handle_file``.
    """
    if image is None:
        raise gr.Error("Please upload an image before generating a 3D model.")
    # NOTE(review): the upload is sent as-is; ``preprocess_image`` above is
    # never called. Confirm the Space does not expect a preprocessed image.
    result = client.predict(
        image=handle_file(image),
        multiimages=[],  # single-image mode: no additional views are sent
        seed=seed,
        ss_guidance_strength=ss_guidance_strength,
        ss_sampling_steps=ss_sampling_steps,
        slat_guidance_strength=slat_guidance_strength,
        slat_sampling_steps=slat_sampling_steps,
        multiimage_algo=multiimage_algo,
        api_name="/image_to_3d",
    )
    return result["video"]


def extract_glb(mesh_simplify, texture_size):
    """Request a GLB export from ``/extract_glb``.

    Args:
        mesh_simplify: Forwarded unchanged to the endpoint.
        texture_size: Forwarded unchanged to the endpoint.

    Returns:
        Element ``[1]`` of the endpoint's result, presumably the GLB file path
        offered for download.
    """
    result = client.predict(
        mesh_simplify=mesh_simplify,
        texture_size=texture_size,
        api_name="/extract_glb",
    )
    return result[1]  # Return the GLB file path for download


def extract_gaussian():
    """Request a Gaussian export from ``/extract_gaussian``.

    Returns:
        Element ``[1]`` of the endpoint's result, presumably the Gaussian file
        path offered for download.
    """
    result = client.predict(api_name="/extract_gaussian")
    return result[1]  # Return the Gaussian file path for download


# Define Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Image to 3D Model with TRELLIS API")

    with gr.Row():
        with gr.Column():
            image = gr.Image(type="filepath", label="Upload Image")
            seed = gr.Slider(0, 100, value=0, step=1, label="Seed")
            ss_guidance_strength = gr.Slider(
                0.0, 10.0, value=7.5, step=0.1, label="SS Guidance Strength"
            )
            ss_sampling_steps = gr.Slider(
                1, 50, value=12, step=1, label="SS Sampling Steps"
            )
            slat_guidance_strength = gr.Slider(
                0.0, 10.0, value=3.0, step=0.1, label="SLAT Guidance Strength"
            )
            slat_sampling_steps = gr.Slider(
                1, 50, value=12, step=1, label="SLAT Sampling Steps"
            )
            multiimage_algo = gr.Radio(
                ["stochastic", "multidiffusion"],
                value="stochastic",
                label="Multi-image Algorithm",
            )
            generate_btn = gr.Button("Generate 3D Model")

        with gr.Column():
            video_output = gr.Video(label="3D Model Preview")
            download_glb_btn = gr.Button("Download GLB")
            download_gaussian_btn = gr.Button("Download Gaussian")
            glb_file = gr.File(label="GLB File")
            gaussian_file = gr.File(label="Gaussian File")

    # Define Actions
    generate_btn.click(
        fn=image_to_3d,
        inputs=[
            image,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            multiimage_algo,
        ],
        outputs=[video_output],
    )

    # Download Buttons and Actions
    # NOTE(review): these are clickable before any model has been generated;
    # what the remote extract endpoints do in that case is not visible here.
    download_glb_btn.click(
        fn=lambda: extract_glb(mesh_simplify=0.95, texture_size=1024),  # Static values
        inputs=[],  # No dynamic inputs
        outputs=[glb_file],  # Output file for GLB
    )
    download_gaussian_btn.click(
        fn=extract_gaussian,  # Direct function call
        inputs=[],  # No dynamic inputs
        outputs=[gaussian_file],  # Output file for Gaussian
    )

# Launch Gradio App
if __name__ == "__main__":
    demo.launch(show_error=True)