Spaces:
Runtime error
Runtime error
File size: 2,631 Bytes
f6b8b7e 789e6b5 13298a2 789e6b5 216a041 b85438c 789e6b5 66e43e3 13298a2 216a041 f466dd9 789e6b5 c301a62 6449f8f 789e6b5 f466dd9 b5ad13a 789e6b5 d26a101 789e6b5 b0bcf89 e0ec116 b0bcf89 e0ec116 6449f8f 789e6b5 b0bcf89 789e6b5 b0bcf89 6035350 f466dd9 789e6b5 d05fa5e 1adc78a e0ec116 d05fa5e 789e6b5 58f74fc f466dd9 a9b8939 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import asyncio
import base64
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image

from generate_propmts import generate_prompt
# Load the pipeline once at import time so every request reuses it.
# Run on the GPU in half precision when CUDA is available; otherwise stay on
# the CPU in full precision (float16 is only selected for CUDA).
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16 if _DEVICE == "cuda" else torch.float32,
).to(_DEVICE)
async def generate_image(prompt, pipeline=None):
    """Render *prompt* with the text-to-image pipeline and return JPEG bytes.

    The pipeline call is synchronous, so it runs in a worker thread via
    ``asyncio.to_thread``. This keeps the event loop that awaits this
    coroutine responsive.

    Args:
        prompt: Text prompt passed to the pipeline.
        pipeline: Callable to use instead of the module-level ``model``
            (the default when ``None``). Calling it must return an object
            whose ``images`` attribute is a list of image objects.

    Returns:
        JPEG-encoded bytes of the first generated image, or ``None`` when
        generation or encoding fails or no image is returned. Errors are
        printed, not raised.
    """
    pipe = model if pipeline is None else pipeline
    try:
        # One step with guidance disabled: the sampling settings the
        # SDXL-Turbo model card recommends.
        output = await asyncio.to_thread(
            pipe, prompt=prompt, num_inference_steps=1, guidance_scale=0.0
        )
        print(f"Model output: {output}")

        images = getattr(output, "images", None)
        if not isinstance(images, list) or not images:
            print("Error generating image: No images returned by the model.")
            return None

        image = images[0]
        # JPEG has no alpha channel or palette mode; normalise to RGB first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return buffered.getvalue()
    except Exception as e:
        print(f"Error generating image: {e}")
        return None
def _generate_image_sync(prompt):
    """Run ``generate_image`` to completion on the calling worker thread.

    ``generate_image`` may be a coroutine function. Calling it from a thread
    pool then only creates a coroutine object, which must be driven by an
    event loop of this thread's own. A plain synchronous ``generate_image``
    is also accepted.
    """
    result = generate_image(prompt)
    if asyncio.iscoroutine(result):
        # Executor worker threads have no running loop, so asyncio.run is safe.
        result = asyncio.run(result)
    return result


async def process_prompt(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph; return them as base64 JPEG strings.

    Args:
        sentence_mapping: Mapping of paragraph number -> sentences (a list of
            strings; a single string is also accepted). It comes from the
            Gradio JSON input and is also passed to ``generate_prompt``.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Style chosen in the dropdown, passed to
            ``generate_prompt``.

    Returns:
        dict mapping paragraph number -> base64-encoded (ASCII) JPEG data.
        Paragraphs whose image could not be produced are omitted. An empty
        or missing mapping yields ``{}``.
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    images = {}
    if not sentence_mapping:
        return images

    # Build one prompt per paragraph.
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        # " ".join on a str would put spaces between its characters.
        combined_sentence = sentences if isinstance(sentences, str) else " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    loop = asyncio.get_running_loop()
    # NOTE(review): every worker calls the same module-level pipeline; confirm
    # it is safe to invoke concurrently, otherwise use max_workers=1.
    with ThreadPoolExecutor() as executor:
        # Await inside the ``with``: leaving it calls shutdown(wait=True), which
        # would otherwise block the event loop until every image finished.
        results = await asyncio.gather(
            *(loop.run_in_executor(executor, _generate_image_sync, prompt) for _, prompt in prompts),
            return_exceptions=True,  # one failed paragraph must not drop the rest
        )

    for (paragraph_number, _), result in zip(prompts, results):
        if isinstance(result, BaseException):
            print(f"Error processing paragraph {paragraph_number}: {result}")
        elif result:
            # Raw bytes are not JSON-serialisable and the interface output is "json".
            images[paragraph_number] = base64.b64encode(result).decode("ascii")
    return images
# Art styles offered in the style dropdown.
_STYLE_CHOICES = ["oil painting", "sketch", "watercolor"]

# Input components, in the same order as process_prompt's parameters.
_INTERFACE_INPUTS = [
    gr.JSON(label="Sentence Mapping"),
    gr.JSON(label="Character Dict"),
    gr.Dropdown(choices=_STYLE_CHOICES, label="Selected Style"),
]

# JSON-in / JSON-out interface around process_prompt; at most 10 requests
# execute concurrently.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=_INTERFACE_INPUTS,
    outputs="json",
    concurrency_limit=10,
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    gradio_interface.launch()
|