Spaces: Running on L40S
import io
import tempfile

import cairosvg
import gradio as gr
import speech_recognition as sr
import svgwrite
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the StarVector model once at import time; every request reuses it.
# NOTE(review): this checkpoint is published as an image-to-SVG model whose
# architecture is not built into transformers (TODO confirm) — verify it
# accepts plain text prompts before relying on this app's text/voice path.
MODEL_NAME = "starvector/starvector-8b-im2svg"

# from_pretrained with no device/dtype leaves the weights on the CPU in fp32,
# which is impractically slow for an 8B-parameter model. Use the GPU when one
# is present (in half precision, halving weight memory); stay fp32 on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# trust_remote_code lets the Auto* classes import modelling code shipped in
# the model repository, required for architectures transformers lacks.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    trust_remote_code=True,
).to(DEVICE)
def generate_svg(prompt, width, height, max_new_tokens=512):
    """Generate SVG markup for *prompt* and save it as SVG and PNG files.

    Args:
        prompt: Text prompt fed to the module-level ``tokenizer``/``model``.
        width: Value written into the root ``<svg>`` ``width`` attribute.
        height: Value written into the root ``<svg>`` ``height`` attribute.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt.

    Returns:
        ``(svg_markup, png_path, svg_path)``: the wrapped SVG text plus the
        paths of newly created PNG and SVG files holding the result.

    Raises:
        Exceptions from generation or from ``cairosvg.svg2png`` (e.g. on
        malformed markup) propagate to the caller unchanged.
    """
    # Keep the input tensors on the same device as the model weights.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # max_new_tokens bounds only the continuation; max_length would also count
    # the prompt tokens and shrink the SVG budget for long prompts.
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    generated = outputs[0].tolist()
    prompt_ids = inputs["input_ids"][0].tolist()
    # Causal LMs normally return prompt + continuation; strip the echoed prompt
    # (only if it is actually there) so the user's text is not placed inside
    # the SVG markup.
    if generated[: len(prompt_ids)] == prompt_ids:
        generated = generated[len(prompt_ids):]
    svg_code = tokenizer.decode(generated, skip_special_tokens=True)

    # Ensure SVG is properly wrapped
    svg_wrapped = (
        f'<svg width="{width}" height="{height}" '
        f'xmlns="http://www.w3.org/2000/svg">{svg_code}</svg>'
    )
    # Convert to PNG
    png_bytes = cairosvg.svg2png(bytestring=svg_wrapped.encode("utf-8"))

    # Unique temp files instead of fixed "output.*" names in the working
    # directory: concurrent requests would otherwise overwrite each other.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".svg", encoding="utf-8", delete=False
    ) as svg_file:
        svg_file.write(svg_wrapped)
    with tempfile.NamedTemporaryFile("wb", suffix=".png", delete=False) as png_file:
        png_file.write(png_bytes)
    return svg_wrapped, png_file.name, svg_file.name
def transcribe_audio(audio_path):
    """Turn the speech recorded in an audio file into text.

    Args:
        audio_path: Path to an audio file readable by ``sr.AudioFile``.

    Returns:
        The transcript produced by ``Recognizer.recognize_google``.

    Errors raised by ``speech_recognition`` propagate unchanged.
    """
    clip = sr.AudioFile(audio_path)
    engine = sr.Recognizer()
    # Read the entire clip into memory, then hand it to the recognizer.
    with clip as stream:
        captured = engine.record(stream)
    return engine.recognize_google(captured)
with gr.Blocks() as demo:
    gr.Markdown("## Vector Logo Generator (Text + Voice)")
    with gr.Row():
        txt = gr.Textbox(label="Text Prompt")
        # Gradio 4+ expects a list of `sources`; the old `source=` keyword
        # was removed and raises a TypeError there.
        mic = gr.Audio(
            sources=["microphone"], type="filepath", label="Or speak your prompt"
        )
    with gr.Row():
        width = gr.Slider(
            minimum=100, maximum=1000, value=500, step=10, label="Width (px)"
        )
        height = gr.Slider(
            minimum=100, maximum=1000, value=500, step=10, label="Height (px)"
        )
    svg_output = gr.Textbox(label="SVG Code Output")
    png_output = gr.Image(label="PNG Preview")
    svg_file = gr.File(label="Download SVG")
    png_file = gr.File(label="Download PNG")

    def run(prompt, audio, w, h):
        """Generate the logo from the typed prompt, or from speech if none.

        Returns one value per wired output component:
        ``(svg_markup, png_preview, svg_download, png_download)``.

        Raises:
            gr.Error: when no prompt is available or transcription fails,
                so the user sees a message instead of a traceback.
        """
        # Typed text takes precedence; fall back to transcribing the recording.
        if not prompt and audio:
            try:
                prompt = transcribe_audio(audio)
            except sr.UnknownValueError:
                raise gr.Error("Could not understand the recorded audio.") from None
            except sr.RequestError as err:
                raise gr.Error(f"Speech recognition failed: {err}") from err
        if not prompt:
            raise gr.Error("Please enter a text prompt or record one.")
        svg, png_path, svg_path = generate_svg(prompt, w, h)
        # Four outputs are wired below; the PNG path feeds both the preview
        # image and the PNG download.
        return svg, png_path, svg_path, png_path

    run_button = gr.Button("Generate")
    run_button.click(
        fn=run,
        inputs=[txt, mic, width, height],
        outputs=[svg_output, png_output, svg_file, png_file],
    )

demo.launch()