Spaces:
Runtime error
Runtime error
import asyncio
import concurrent.futures
import json

import gradio as gr
from diffusers import AutoPipelineForText2Image

from generate_prompts import generate_prompt
# Load the text-to-image pipeline once, at import time, so every call to
# generate_image reuses the same pipeline object instead of reloading it.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
async def generate_image(prompt, prompt_name):
    """Generate one image for *prompt* and return it as bytes.

    Args:
        prompt: Text prompt forwarded to the module-level ``model`` pipeline.
        prompt_name: Label used only in the progress and error messages.

    Returns:
        The bytes produced by ``tobytes()`` on the first generated image, or
        ``None`` if anything went wrong while generating it.
    """
    try:
        print(f"Generating image for {prompt_name}")
        # The pipeline call is an ordinary blocking call and its result is not
        # awaitable. Awaiting it raises TypeError once generation has finished,
        # which the handler below would swallow, turning every result into
        # None. It is called directly (not in a worker thread) so that calls
        # on the single shared pipeline never overlap.
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        image = output.images[0]
        img_bytes = image.tobytes()
        print(f"Image bytes length for {prompt_name}: {len(img_bytes)}")
        return img_bytes
    except Exception as e:
        # Best effort: a failure for one prompt is reported and mapped to None
        # so the remaining prompts in the batch can still be processed.
        print(f"Error generating image for {prompt_name}: {e}")
        return None
async def queue_image_calls(prompts):
    """Run ``generate_image`` for every prompt and gather the results.

    Args:
        prompts: Iterable of prompts. Each one is labelled ``"Prompt <index>"``
            according to its position.

    Returns:
        A list with one ``generate_image`` result per prompt, in prompt order.
    """
    tasks = [
        generate_image(prompt, f"Prompt {index}")
        for index, prompt in enumerate(prompts)
    ]
    return await asyncio.gather(*tasks)
def async_image_generation(prompts):
    """Synchronously run ``queue_image_calls`` for *prompts*.

    Args:
        prompts: Passed through unchanged to ``queue_image_calls``.

    Returns:
        Whatever ``queue_image_calls(prompts)`` resolves to.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread. asyncio.run creates a fresh loop
        # and closes it afterwards, so no event loop is leaked per call.
        return asyncio.run(queue_image_calls(prompts))

    # A loop is already running in this thread, and run_until_complete() on a
    # running loop raises RuntimeError. Drive the coroutine on a private loop
    # in a helper thread instead and block here until it has finished.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, queue_image_calls(prompts)).result()
def gradio_interface(sentence_mapping, character_dict, selected_style):
    """Turn the form inputs into prompts, render them, and return updates.

    Args:
        sentence_mapping: Forwarded as-is to ``generate_prompt``.
        character_dict: Forwarded as-is to ``generate_prompt``.
        selected_style: Forwarded as-is to ``generate_prompt``.

    Returns:
        A list with one component update per generation result. A failed or
        empty result becomes an update whose value is ``None``.
    """
    prompts = generate_prompt(sentence_mapping, character_dict, selected_style)
    image_bytes_list = async_image_generation(prompts)
    # gr.update(...) is the component-agnostic update helper; unlike the
    # per-class gr.Image.update(...) it is also available in newer Gradio
    # releases.
    # NOTE(review): the values are the bytes returned by generate_image --
    # confirm the receiving Image components can display that format.
    return [gr.update(value=img_bytes or None) for img_bytes in image_bytes_list]
# Gradio Interface
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been flattened -- confirm it matches the intended
# layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            sentence_mapping_input = gr.Textbox(label="Sentence Mapping")
            character_dict_input = gr.Textbox(label="Character Dictionary")
            selected_style_input = gr.Textbox(label="Selected Style")
            submit_btn = gr.Button(value='Submit')

    # Placeholder for dynamically added Image components. It is still empty
    # when the events below are registered, so neither event has an output
    # component to write to.
    prompt_responses = []

    # The load event is registered with no inputs, so its callback must accept
    # no arguments; a one-argument callback raises TypeError when invoked.
    demo.load(fn=lambda: None, inputs=[], outputs=prompt_responses)

    submit_btn.click(
        fn=gradio_interface,
        inputs=[sentence_mapping_input, character_dict_input, selected_style_input],
        outputs=prompt_responses,
    )

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()