TheSquidBaron committed on
Commit a86d422 · verified · 1 Parent(s): 6cf650b

Update app.py

Files changed (1)
  1. app.py +60 -140
app.py CHANGED
@@ -1,141 +1,61 @@
- import argparse
  import gradio as gr
- import os
- import torch
- import trimesh
- import sys
- from pathlib import Path
-
- pathdir = Path(__file__).parent / 'cube'
- sys.path.append(pathdir.as_posix())
-
- # print(__file__)
- # print(os.listdir())
- # print(os.listdir('cube'))
- # print(pathdir.as_posix())
-
- from cube3d.inference.engine import EngineFast, Engine
- from pathlib import Path
- import uuid
- import shutil
- from huggingface_hub import snapshot_download
-
-
- GLOBAL_STATE = {}
-
- def gen_save_folder(max_size=200):
-     os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)
-
-     dirs = [f for f in Path(GLOBAL_STATE["SAVE_DIR"]).iterdir() if f.is_dir()]
-
-     if len(dirs) >= max_size:
-         oldest_dir = min(dirs, key=lambda x: x.stat().st_ctime)
-         shutil.rmtree(oldest_dir)
-         print(f"Removed the oldest folder: {oldest_dir}")
-
-     new_folder = os.path.join(GLOBAL_STATE["SAVE_DIR"], str(uuid.uuid4()))
-     os.makedirs(new_folder, exist_ok=True)
-     print(f"Created new folder: {new_folder}")
-
-     return new_folder
-
- def handle_text_prompt(input_prompt, variance = 0):
-     print(f"prompt: {input_prompt}, variance: {variance}")
-     top_p = None if variance == 0 else (100 - variance) / 100.0
-     mesh_v_f = GLOBAL_STATE["engine_fast"].t2s([input_prompt], use_kv_cache=True, resolution_base=8.0, top_p=top_p)
-     # save output
-     vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
-     save_folder = gen_save_folder()
-     output_path = os.path.join(save_folder, "output.glb")
-     trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
-     return output_path
-
- def build_interface():
-     """Build UI for gradio app
-     """
-     title = "Cube 3D"
-     with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
-         gr.Markdown(
-             f"""
-             # {title}
-             # Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!
-             """
-         )
-
-         with gr.Row():
-             with gr.Column(scale=2):
-                 with gr.Group():
-                     input_text_box = gr.Textbox(
-                         value=None,
-                         label="Prompt",
-                         lines=2,
-                     )
-                     variance = gr.Slider(minimum=0, maximum=99, step=1, value=0, label="Variance")
-                 with gr.Row():
-                     submit_button = gr.Button("Submit", variant="primary")
-             with gr.Column(scale=3):
-                 model3d = gr.Model3D(
-                     label="Output", height="45em", interactive=False
-                 )
-
-         submit_button.click(
-             handle_text_prompt,
-             inputs=[
-                 input_text_box,
-                 variance
-             ],
-             outputs=[
-                 model3d
-             ]
-         )
-
-     return interface
-
- if __name__=="__main__":
-
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--config_path",
-         type=str,
-         help="Path to the config file",
-         default="cube/cube3d/configs/open_model.yaml",
-     )
-     parser.add_argument(
-         "--gpt_ckpt_path",
-         type=str,
-         help="Path to the gpt ckpt path",
-         default="model_weights/shape_gpt.safetensors",
-     )
-     parser.add_argument(
-         "--shape_ckpt_path",
-         type=str,
-         help="Path to the shape ckpt path",
-         default="model_weights/shape_tokenizer.safetensors",
-     )
-     parser.add_argument(
-         "--save_dir",
-         type=str,
-         default="gradio_save_dir",
-     )
-
-     args = parser.parse_args()
-     snapshot_download(
-         repo_id="Roblox/cube3d-v0.1",
-         local_dir="./model_weights"
-     )
-     config_path = args.config_path
-     gpt_ckpt_path = "./model_weights/shape_gpt.safetensors"
-     shape_ckpt_path = "./model_weights/shape_tokenizer.safetensors"
-     engine_fast = EngineFast(
-         config_path,
-         gpt_ckpt_path,
-         shape_ckpt_path,
-         device=torch.device("cuda"),
-     )
-     GLOBAL_STATE["engine_fast"] = engine_fast
-     GLOBAL_STATE["SAVE_DIR"] = args.save_dir
-     os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)
-
-     demo = build_interface()
-     demo.queue(default_concurrency_limit=1)
-     demo.launch()
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import svgwrite
+ import cairosvg
+ import speech_recognition as sr
+ import io
+
+ # Load the StarVector model
+ tokenizer = AutoTokenizer.from_pretrained("starvector/starvector-8b-im2svg")
+ model = AutoModelForCausalLM.from_pretrained("starvector/starvector-8b-im2svg")
+
+ def generate_svg(prompt, width, height):
+     inputs = tokenizer(prompt, return_tensors="pt")
+     outputs = model.generate(**inputs, max_length=512)
+     svg_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Ensure SVG is properly wrapped
+     svg_wrapped = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">{svg_code}</svg>'
+
+     # Convert to PNG
+     png_output = cairosvg.svg2png(bytestring=svg_wrapped.encode('utf-8'))
+
+     with open("output.svg", "w") as f:
+         f.write(svg_wrapped)
+     with open("output.png", "wb") as f:
+         f.write(png_output)
+
+     return svg_wrapped, "output.png", "output.svg"
+
+ def transcribe_audio(audio_path):
+     recognizer = sr.Recognizer()
+     with sr.AudioFile(audio_path) as source:
+         audio_data = recognizer.record(source)
+     return recognizer.recognize_google(audio_data)
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## Vector Logo Generator (Text + Voice)")
+
+     with gr.Row():
+         txt = gr.Textbox(label="Text Prompt")
+         mic = gr.Audio(source="microphone", type="filepath", label="Or speak your prompt")
+
+     with gr.Row():
+         width = gr.Slider(minimum=100, maximum=1000, value=500, step=10, label="Width (px)")
+         height = gr.Slider(minimum=100, maximum=1000, value=500, step=10, label="Height (px)")
+
+     svg_output = gr.Textbox(label="SVG Code Output")
+     png_output = gr.Image(label="PNG Preview")
+     svg_file = gr.File(label="Download SVG")
+     png_file = gr.File(label="Download PNG")
+
+     def run(prompt, audio, w, h):
+         if not prompt and audio:
+             prompt = transcribe_audio(audio)
+         svg, png_path, svg_path = generate_svg(prompt, w, h)
+         return svg, png_path, svg_path, png_path  # one value per click() output
+
+     run_button = gr.Button("Generate")
+     run_button.click(fn=run, inputs=[txt, mic, width, height], outputs=[svg_output, png_output, svg_file, png_file])
+
+ demo.launch()