Spaces:
Runtime error
Runtime error
File size: 3,361 Bytes
5e2c7ed 690f094 9da79fd 5e2c7ed b85438c 0cdadc9 6dc4bbd 13298a2 8551936 f466dd9 5e2c7ed 8551936 690f094 f466dd9 c9b9787 789e6b5 d26a101 5e2c7ed 690f094 bdf16c0 690f094 5abcbb9 8551936 690f094 bdf16c0 8551936 bdf16c0 5e2c7ed bdf16c0 690f094 081cd9c 5e2c7ed bdf16c0 5e2c7ed bdf16c0 5e2c7ed 081cd9c 690f094 bdf16c0 690f094 bdf16c0 f466dd9 bdf16c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import os
import asyncio
from generate_prompts import generate_prompt
from diffusers import AutoPipelineForText2Image
from io import BytesIO
import gradio as gr
# Load the model once outside of the function
# NOTE(review): loaded at import time so all requests share one pipeline and
# weights are fetched only once; first import will block while downloading.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
def generate_image(prompt, prompt_name):
    """Run the text-to-image pipeline on `prompt` and return JPEG bytes.

    Args:
        prompt: Text prompt passed to the SDXL-Turbo pipeline.
        prompt_name: Label used only in log messages.

    Returns:
        The JPEG-encoded image as `bytes`, or None on any failure.
        Failures are printed, never raised, so callers need no try/except.
    """
    try:
        print(f"Generating response for {prompt_name}")
        # SDXL-Turbo is tuned for single-step, guidance-free inference.
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
    except Exception as e:
        print(f"Error generating image for {prompt_name}: {e}")
        return None

    # Guard clause instead of raising a generic Exception just to catch it
    # in our own handler (the original's control-flow antipattern). The
    # printed message is kept identical to what the old raise produced.
    if not (isinstance(output.images, list) and len(output.images) > 0):
        print(f"Error generating image for {prompt_name}: No images returned by the model for {prompt_name}.")
        return None

    image = output.images[0]
    buffered = BytesIO()
    try:
        # PIL raises (e.g. OSError) for modes JPEG cannot encode, such as RGBA.
        image.save(buffered, format="JPEG")
    except Exception as e:
        print(f"Error saving image for {prompt_name}: {e}")
        return None

    image_bytes = buffered.getvalue()
    print(f"Image bytes length for {prompt_name}: {len(image_bytes)}")
    return image_bytes
async def queue_api_calls(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph, then render every image concurrently.

    Returns a dict mapping paragraph number -> JPEG bytes (or None when a
    particular image failed to generate).
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')

    # Phase 1: turn each paragraph's sentences into a single prompt.
    numbered_prompts = []
    for number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        print(f'combined_sentence: {combined_sentence}, character_dict: {character_dict}, selected_style: {selected_style}')
        paragraph_prompt = generate_prompt(combined_sentence, character_dict, selected_style)
        numbered_prompts.append((number, paragraph_prompt))
        print(f"Generated prompt for paragraph {number}: {paragraph_prompt}")

    # Phase 2: fan out the blocking model calls onto the default thread pool
    # so they run in parallel, then gather the results in order.
    loop = asyncio.get_running_loop()
    futures = [
        loop.run_in_executor(None, generate_image, paragraph_prompt, f"Prompt {number}")
        for number, paragraph_prompt in numbered_prompts
    ]
    results = await asyncio.gather(*futures)

    return {
        number: image_bytes
        for (number, _), image_bytes in zip(numbered_prompts, results)
    }
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Synchronous Gradio entry point: run the async pipeline to completion.

    Args:
        sentence_mapping: dict of paragraph number -> list of sentences.
        character_dict: character descriptions forwarded to prompt generation.
        selected_style: art-style string forwarded to prompt generation.

    Returns:
        The dict produced by `queue_api_calls` (paragraph number -> image bytes).

    Bug fixed: the original reused an already-running loop with
    `loop.run_until_complete(...)`, which always raises
    "RuntimeError: This event loop is already running"; its fallback branch
    also created a new loop that was never closed (resource leak).
    `asyncio.run` creates, runs, and tears down a fresh loop correctly.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Normal case: no loop is running in this (Gradio worker) thread.
        return asyncio.run(
            queue_api_calls(sentence_mapping, character_dict, selected_style)
        )

    # A loop IS running in this thread; we cannot block it here. Run the
    # coroutine on a fresh loop in a short-lived worker thread and wait.
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(
            asyncio.run,
            queue_api_calls(sentence_mapping, character_dict, selected_style),
        )
        return future.result()
# Gradio interface with high concurrency limit
# Inputs mirror queue_api_calls' parameters; output is the paragraph->image map.
_input_components = [
    gr.JSON(label="Sentence Mapping"),
    gr.JSON(label="Character Dict"),
    gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
]

gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=_input_components,
    outputs="json",
)

if __name__ == "__main__":
    gradio_interface.launch()
|