theekshana commited on
Commit
60ebd70
·
1 Parent(s): c771b2b
Files changed (2) hide show
  1. app.py +106 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_client import Client, handle_file
2
+ import gradio as gr
3
+
4
+ # Initialize Gradio Client
5
+ client = Client("JeffreyXiang/TRELLIS") # Replace with your Hugging Face Space
6
+
7
+ # Helper Functions for API Calls
8
+ def start_session():
9
+ result = client.predict(api_name="/start_session")
10
+ return result
11
+
12
+ def preprocess_image(image):
13
+ result = client.predict(
14
+ image=handle_file(image),
15
+ api_name="/preprocess_image"
16
+ )
17
+ return result
18
+
19
+ def preprocess_images(images):
20
+ processed_images = [
21
+ {"image": handle_file(img), "caption": None} for img in images
22
+ ]
23
+ result = client.predict(
24
+ images=processed_images,
25
+ api_name="/preprocess_images"
26
+ )
27
+ return result
28
+
29
+ def get_seed(randomize_seed, seed):
30
+ result = client.predict(
31
+ randomize_seed=randomize_seed,
32
+ seed=seed,
33
+ api_name="/get_seed"
34
+ )
35
+ return result
36
+
37
+ def image_to_3d(image, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo):
38
+ result = client.predict(
39
+ image=handle_file(image),
40
+ multiimages=[],
41
+ seed=seed,
42
+ ss_guidance_strength=ss_guidance_strength,
43
+ ss_sampling_steps=ss_sampling_steps,
44
+ slat_guidance_strength=slat_guidance_strength,
45
+ slat_sampling_steps=slat_sampling_steps,
46
+ multiimage_algo=multiimage_algo,
47
+ api_name="/image_to_3d"
48
+ )
49
+ return result["video"]
50
+
51
+ def extract_glb(mesh_simplify, texture_size):
52
+ result = client.predict(
53
+ mesh_simplify=mesh_simplify,
54
+ texture_size=texture_size,
55
+ api_name="/extract_glb"
56
+ )
57
+ return result[1] # Return the GLB file path for download
58
+
59
+ def extract_gaussian():
60
+ result = client.predict(api_name="/extract_gaussian")
61
+ return result[1] # Return the Gaussian file path for download
62
+
63
+ # Define Gradio UI
64
+ with gr.Blocks() as demo:
65
+ gr.Markdown("# Image to 3D Model with TRELLIS API")
66
+
67
+ with gr.Row():
68
+ with gr.Column():
69
+ image = gr.Image(type="filepath", label="Upload Image")
70
+ seed = gr.Slider(0, 100, value=0, step=1, label="Seed")
71
+ ss_guidance_strength = gr.Slider(0.0, 10.0, value=7.5, step=0.1, label="SS Guidance Strength")
72
+ ss_sampling_steps = gr.Slider(1, 50, value=12, step=1, label="SS Sampling Steps")
73
+ slat_guidance_strength = gr.Slider(0.0, 10.0, value=3.0, step=0.1, label="SLAT Guidance Strength")
74
+ slat_sampling_steps = gr.Slider(1, 50, value=12, step=1, label="SLAT Sampling Steps")
75
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], value="stochastic", label="Multi-image Algorithm")
76
+ generate_btn = gr.Button("Generate 3D Model")
77
+
78
+ with gr.Column():
79
+ video_output = gr.Video(label="3D Model Preview")
80
+ download_glb_btn = gr.Button("Download GLB")
81
+ download_gaussian_btn = gr.Button("Download Gaussian")
82
+ glb_file = gr.File(label="GLB File")
83
+ gaussian_file = gr.File(label="Gaussian File")
84
+
85
+ # Define Actions
86
+ generate_btn.click(
87
+ fn=image_to_3d,
88
+ inputs=[image, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
89
+ outputs=[video_output]
90
+ )
91
+
92
+ download_glb_btn.click(
93
+ fn=extract_glb,
94
+ inputs=[0.95, 1024], # Example: default values for mesh_simplify and texture_size
95
+ outputs=[glb_file]
96
+ )
97
+
98
+ download_gaussian_btn.click(
99
+ fn=extract_gaussian,
100
+ inputs=[],
101
+ outputs=[gaussian_file]
102
+ )
103
+
104
+ # Launch Gradio App
105
+ if __name__ == "__main__":
106
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ gradio_client