File size: 3,802 Bytes
60ebd70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0767aa
60ebd70
b0767aa
 
 
60ebd70
 
 
b0767aa
 
 
60ebd70
 
 
 
3613dc8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os

import gradio as gr
from gradio_client import Client, handle_file

# Initialize the Gradio client once, at import time; it is shared by every
# handler in this module. The target defaults to the public TRELLIS Space.
# Set the TRELLIS_SPACE environment variable (a Hugging Face Space id such as
# "user/space", or the full URL of a hosted Gradio app) to point this app at
# your own duplicate without editing the code.
TRELLIS_SPACE = os.environ.get("TRELLIS_SPACE", "JeffreyXiang/TRELLIS")
client = Client(TRELLIS_SPACE)

# Helper Functions for API Calls
def start_session():
    """Hit the Space's ``/start_session`` endpoint.

    Returns:
        Whatever the endpoint returns, unmodified.
    """
    return client.predict(api_name="/start_session")

def preprocess_image(image):
    """Run one image through the Space's ``/preprocess_image`` endpoint.

    Args:
        image: Local path (or URL) of the image; wrapped with ``handle_file``
            so the client uploads it.

    Returns:
        The endpoint's raw result.
    """
    uploaded = handle_file(image)
    return client.predict(image=uploaded, api_name="/preprocess_image")

def preprocess_images(images):
    """Run several images through the Space's ``/preprocess_images`` endpoint.

    Args:
        images: Iterable of local paths (or URLs). Each one is wrapped with
            ``handle_file`` and paired with a ``None`` caption before sending.

    Returns:
        The endpoint's raw result.
    """
    gallery = [dict(image=handle_file(path), caption=None) for path in images]
    return client.predict(images=gallery, api_name="/preprocess_images")

def get_seed(randomize_seed, seed):
    """Ask the Space's ``/get_seed`` endpoint which seed to use.

    Args:
        randomize_seed: Forwarded as the endpoint's ``randomize_seed`` argument.
        seed: Forwarded as the endpoint's ``seed`` argument.

    Returns:
        The endpoint's raw result.
    """
    payload = {"randomize_seed": randomize_seed, "seed": seed}
    return client.predict(**payload, api_name="/get_seed")

def image_to_3d(image, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo):
    """Generate a 3D preview from a single image via the Space's ``/image_to_3d``.

    Args:
        image: Local file path of the input image (the UI's ``gr.Image`` uses
            ``type="filepath"``). It is empty/``None`` when nothing was uploaded.
        seed: Random seed forwarded to the endpoint.
        ss_guidance_strength: Forwarded as ``ss_guidance_strength``.
        ss_sampling_steps: Forwarded as ``ss_sampling_steps``.
        slat_guidance_strength: Forwarded as ``slat_guidance_strength``.
        slat_sampling_steps: Forwarded as ``slat_sampling_steps``.
        multiimage_algo: Forwarded as ``multiimage_algo``. It is still sent
            even though no extra images are provided (``multiimages=[]``).

    Returns:
        The ``"video"`` entry of the endpoint's result, which the UI shows
        in its ``gr.Video`` component.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a failure from ``handle_file``.
    """
    if not image:
        # Fail fast with a user-facing message. Without this guard,
        # handle_file(None) fails with a less helpful error.
        raise gr.Error("Please upload an image before generating a 3D model.")

    result = client.predict(
        image=handle_file(image),
        multiimages=[],  # single-image mode: no additional views are sent
        seed=seed,
        ss_guidance_strength=ss_guidance_strength,
        ss_sampling_steps=ss_sampling_steps,
        slat_guidance_strength=slat_guidance_strength,
        slat_sampling_steps=slat_sampling_steps,
        multiimage_algo=multiimage_algo,
        api_name="/image_to_3d"
    )
    return result["video"]

def extract_glb(mesh_simplify, texture_size):
    """Ask the Space's ``/extract_glb`` endpoint to export a GLB file.

    Args:
        mesh_simplify: Forwarded as the endpoint's ``mesh_simplify`` argument.
        texture_size: Forwarded as the endpoint's ``texture_size`` argument.

    Returns:
        Element ``[1]`` of the endpoint's result: the GLB file path used for
        the download component.
    """
    response = client.predict(
        mesh_simplify=mesh_simplify,
        texture_size=texture_size,
        api_name="/extract_glb",
    )
    glb_path = response[1]
    return glb_path

def extract_gaussian():
    """Ask the Space's ``/extract_gaussian`` endpoint to export the Gaussian file.

    Returns:
        Element ``[1]`` of the endpoint's result: the Gaussian file path used
        for the download component.
    """
    response = client.predict(api_name="/extract_gaussian")
    gaussian_path = response[1]
    return gaussian_path

# Define the Gradio UI: a left column of generation inputs and a right column
# with the preview and download controls. Every event handler below forwards
# to the remote Space through the module-level `client`.
# NOTE(review): the extract handlers send no reference to the generated model,
# so they presumably rely on remote session state held by that single shared
# `client`. If so, concurrent visitors of this app share one remote session;
# confirm before serving multiple users.
# NOTE(review): start_session, preprocess_image, preprocess_images and get_seed
# are defined above but never wired into this UI. Confirm whether the Space
# expects any of them (e.g. preprocessing the image before /image_to_3d).
with gr.Blocks() as demo:
    gr.Markdown("# Image to 3D Model with TRELLIS API")

    with gr.Row():
        # Left column: generation inputs. The order of `inputs=` below must
        # match image_to_3d's positional parameters.
        with gr.Column():
            image = gr.Image(type="filepath", label="Upload Image")  # handler receives a file path
            seed = gr.Slider(0, 100, value=0, step=1, label="Seed")  # this UI limits seeds to 0-100
            ss_guidance_strength = gr.Slider(0.0, 10.0, value=7.5, step=0.1, label="SS Guidance Strength")
            ss_sampling_steps = gr.Slider(1, 50, value=12, step=1, label="SS Sampling Steps")
            slat_guidance_strength = gr.Slider(0.0, 10.0, value=3.0, step=0.1, label="SLAT Guidance Strength")
            slat_sampling_steps = gr.Slider(1, 50, value=12, step=1, label="SLAT Sampling Steps")
            multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], value="stochastic", label="Multi-image Algorithm")
            generate_btn = gr.Button("Generate 3D Model")

        # Right column: preview video plus buttons that export files from the
        # most recent generation.
        with gr.Column():
            video_output = gr.Video(label="3D Model Preview")
            download_glb_btn = gr.Button("Download GLB")
            download_gaussian_btn = gr.Button("Download Gaussian")
            glb_file = gr.File(label="GLB File")
            gaussian_file = gr.File(label="Gaussian File")

    # Generate: run image_to_3d with the seven inputs (positional order) and
    # show the returned video in the preview component.
    generate_btn.click(
        fn=image_to_3d,
        inputs=[image, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
        outputs=[video_output]
    )

    # Download buttons: each export reads no UI inputs and fills a gr.File.
    download_glb_btn.click(
        fn=lambda: extract_glb(mesh_simplify=0.95, texture_size=1024),  # fixed export settings, not user-adjustable
        inputs=[],  # no UI inputs are read
        outputs=[glb_file]  # receives the GLB path returned by extract_glb
    )

    download_gaussian_btn.click(
        fn=extract_gaussian,  # passed as the handler; Gradio calls it on click
        inputs=[],  # no UI inputs are read
        outputs=[gaussian_file]  # receives the path returned by extract_gaussian
    )

def main():
    """Start the local Gradio server, showing backend errors in the browser."""
    demo.launch(show_error=True)


# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()