Spaces:
Runtime error
Runtime error
File size: 1,994 Bytes
f466dd9 a9b8939 9e09422 cd715cb b5ad13a a9b8939 6d1d03a 8a0f059 f466dd9 b5ad13a a9b8939 8a0f059 b5ad13a 8a0f059 b5ad13a f466dd9 b5ad13a 8a0f059 f466dd9 b5ad13a 8a0f059 b5ad13a 8a0f059 b5ad13a ca1d41c f466dd9 b5ad13a f466dd9 a9b8939 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import base64
import json
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image

from generate_propmts import generate_prompt
# Load the text-to-image pipeline once at import time so every request reuses it.
# When a CUDA device is present, load the half-precision weight variant and move
# the pipeline to the GPU; otherwise keep the default full-precision CPU load
# (float16 kernels are often missing or slow on CPU).
MODEL_ID = "stabilityai/sdxl-turbo"
if torch.cuda.is_available():
    model = AutoPipelineForText2Image.from_pretrained(
        MODEL_ID, torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")
else:
    model = AutoPipelineForText2Image.from_pretrained(MODEL_ID)
def generate_image(text, sentence_mapping, character_dict, selected_style):
    """Render one image for *text* and return it as a base64-encoded JPEG string.

    The prompt comes from ``generate_prompt``, which must return exactly two
    values; only the first is used. The image is produced by the module-level
    ``model`` with a single inference step and ``guidance_scale=0.0``.

    Returns:
        The JPEG bytes encoded as base64 and decoded to a UTF-8 ``str``, or
        ``None`` if any step raised. The error is printed, not re-raised.
    """
    try:
        prompt, _ignored = generate_prompt(
            text, sentence_mapping, character_dict, selected_style
        )
        result = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        picture = result.images[0]
        with BytesIO() as jpeg_buffer:
            picture.save(jpeg_buffer, format="JPEG")
            raw_bytes = jpeg_buffer.getvalue()
        return base64.b64encode(raw_bytes).decode("utf-8")
    except Exception as exc:
        # Best-effort: one failed image should not abort the whole batch.
        print(f"Error generating image: {exc}")
        return None
def _parse_json_arg(value, name):
    """Decode *value* from JSON when it is a string; return other values unchanged.

    A blank or whitespace-only string decodes to an empty dict.

    Raises:
        ValueError: if *value* is a non-blank string that is not valid JSON.
    """
    if not isinstance(value, str):
        return value
    if not value.strip():
        # Empty Gradio textboxes arrive as "", which json.loads rejects.
        return {}
    try:
        return json.loads(value)
    except json.JSONDecodeError as err:
        raise ValueError(f"{name} must be valid JSON: {err}") from err


def inference(text, sentence_mapping, character_dict, selected_style, max_workers=1):
    """Generate one image per paragraph described by *sentence_mapping*.

    Args:
        text: Accepted for the Gradio input signature; not used by this function.
        sentence_mapping: Mapping of paragraph key -> list of sentences (a
            plain string is also accepted and used verbatim). May be a dict or
            a JSON object string, as delivered by a Gradio ``Textbox``. The
            decoded mapping is forwarded to ``generate_image``.
        character_dict: Forwarded to ``generate_image``; decoded from JSON
            first when given as a string (blank string -> ``{}``).
        selected_style: Forwarded unchanged to ``generate_image``.
        max_workers: Thread count for generation. NOTE: every worker calls the
            same module-level diffusers pipeline, whose scheduler keeps
            per-call state, so concurrent calls may interfere. The default of
            1 serializes them; raise it only if generation is made thread-safe.

    Returns:
        Dict mapping each paragraph key to the base64 JPEG string returned by
        ``generate_image``, or ``None`` where generation failed.

    Raises:
        ValueError: if a string argument is not valid JSON, or if
            *sentence_mapping* does not decode to a dict.
    """
    mapping = _parse_json_arg(sentence_mapping, "sentence_mapping")
    if not isinstance(mapping, dict):
        raise ValueError(
            "sentence_mapping must be a JSON object mapping paragraph numbers "
            f"to sentences, got {type(mapping).__name__}"
        )
    characters = _parse_json_arg(character_dict, "character_dict")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {}
        for paragraph_number, sentences in mapping.items():
            # " ".join on a str would put a space between every character,
            # so a paragraph given as one string is passed through as-is.
            combined = sentences if isinstance(sentences, str) else " ".join(sentences)
            futures[paragraph_number] = executor.submit(
                generate_image, combined, mapping, characters, selected_style
            )
        images = {key: future.result() for key, future in futures.items()}
    return images
# Gradio UI wiring for `inference`. The `gr.inputs` namespace was removed in
# Gradio 4, so the top-level components (available since Gradio 3) are used.
# Every input except the style dropdown is a free-text box, so `sentence_mapping`
# and `character_dict` reach `inference` as strings.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(label="Text"),
        gr.Textbox(label="Sentence Mapping"),
        gr.Textbox(label="Character Dict"),
        gr.Dropdown(["Style 1", "Style 2", "Style 3"], label="Selected Style"),
    ],
    outputs="json",  # renders the dict returned by `inference`
)
def _main():
    """Start the Gradio server for `gradio_interface`."""
    gradio_interface.launch()


if __name__ == "__main__":
    _main()
|