Spaces:
Runtime error
Runtime error
File size: 3,696 Bytes
c513221 109adde 5c3986b 5e2c7ed 6b1b953 6706e0b 5e2c7ed eb48f29 86743ba 5c3986b 6b1b953 5d9bf5a 33d78b0 eb48f29 6b1b953 efca12e 6b1b953 efca12e 6b1b953 6292767 5c3986b eb48f29 194be56 6292767 eb48f29 3b7350e 834f7ba 5c3986b 28413d5 690f094 d253f4a 690f094 5c3986b fd77b23 33d78b0 5c3986b 33d78b0 9cd3a95 cfeca25 efca12e 690f094 081cd9c 5e2c7ed 5c3986b 109adde 5c3986b 109adde 5c3986b 109adde 834f7ba efca12e 109adde 081cd9c 690f094 bdf16c0 5c3986b bdf16c0 eb48f29 bdf16c0 f466dd9 5c3986b 630a72e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import os
import asyncio
from concurrent.futures import ProcessPoolExecutor
from io import BytesIO
from PIL import Image
from diffusers import StableDiffusionPipeline
import gradio as gr
from generate_prompts import generate_prompt
# Load the model once at the start.
# sdxl-turbo is an SDXL checkpoint (two text encoders, UNet conditioned on both),
# so it must be loaded through an SDXL-capable pipeline; the SD 1.x/2.x
# StableDiffusionPipeline cannot run it and every generation would fail.
# AutoPipelineForText2Image picks the pipeline class recorded in the checkpoint.
# On any failure (including an older diffusers without AutoPipeline) `model`
# stays None and generate_image reports "Model not loaded properly.".
print("Loading the Stable Diffusion model...")
try:
    from diffusers import AutoPipelineForText2Image

    model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
def generate_image(prompt, prompt_name):
    """Render ``prompt`` with the module-level ``model`` and return JPEG bytes.

    The pipeline is called with one inference step and guidance disabled
    (``guidance_scale=0.0``). Errors are never raised to the caller: any
    exception -- including a model that failed to load or an output without
    images -- is printed and converted into a ``None`` result.

    Args:
        prompt: Text prompt passed to the pipeline.
        prompt_name: Human-readable label used only in log messages.

    Returns:
        The first generated image encoded as JPEG ``bytes``, or ``None`` if
        anything went wrong.
    """
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")

        print(f"Generating image for {prompt_name} with prompt: {prompt}")
        result = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output for {prompt_name}: {result}")
        if result is None:
            raise ValueError(f"Model returned None for {prompt_name}")

        # A missing ``images`` attribute and an empty image list are handled alike.
        generated = getattr(result, "images", None)
        if not generated:
            print(f"No images found in model output for {prompt_name}")
            raise ValueError(f"No images found in model output for {prompt_name}")

        print(f"Image generated for {prompt_name}")
        with BytesIO() as sink:
            generated[0].save(sink, format="JPEG")
            return sink.getvalue()
    except Exception as e:
        print(f"An error occurred while generating image for {prompt_name}: {e}")
        return None
async def queue_api_calls(sentence_mapping, character_dict, selected_style, max_workers=None):
    """Build one prompt per paragraph and render the images in worker processes.

    Args:
        sentence_mapping: Mapping of paragraph number -> sentences of that
            paragraph; the sentences are joined with single spaces before
            being turned into a prompt. ``None`` or an empty mapping yields
            ``{}`` without starting any worker process.
        character_dict: Forwarded unchanged to ``generate_prompt``.
        selected_style: Forwarded unchanged to ``generate_prompt``.
        max_workers: Optional upper bound on worker processes. ``None`` keeps
            the CPU-count default. The pool is never larger than the number
            of paragraphs to render.

    Returns:
        Dict mapping each paragraph number to whatever ``generate_image``
        returned for its prompt (image bytes, or ``None`` on failure), in
        ``sentence_mapping`` iteration order.
    """
    print("Starting to queue API calls...")
    if not sentence_mapping:
        # Nothing to render: skip prompt building and the process pool entirely.
        print("No paragraphs to process.")
        return {}

    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # generate_image reads the module-level model, so every worker process
    # carries its own copy of it; never start more workers than there are prompts.
    pool_size = min(len(prompts), max_workers or os.cpu_count() or 1)

    loop = asyncio.get_running_loop()
    with ProcessPoolExecutor(max_workers=pool_size) as pool:
        tasks = [
            loop.run_in_executor(pool, generate_image, prompt, f"Prompt {paragraph_number}")
            for paragraph_number, prompt in prompts
        ]
        responses = await asyncio.gather(*tasks)

    images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
    # Log a summary only: the values are raw image bytes and would flood the log.
    succeeded = sum(response is not None for response in responses)
    print(f"Finished queuing API calls. Generated {succeeded}/{len(images)} images for paragraphs: {list(images)}")
    return images
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Synchronous Gradio handler: run ``queue_api_calls`` to completion.

    ``asyncio.run`` creates a fresh event loop, runs the coroutine and always
    closes the loop afterwards, so no loop is leaked per request. Reusing a
    running loop is not an option here: ``run_until_complete`` raises on a
    loop that is already running.

    Args:
        sentence_mapping: Paragraph number -> sentences, from the first JSON input.
        character_dict: Value of the second JSON input, forwarded as-is.
        selected_style: Style chosen in the dropdown, forwarded as-is.

    Returns:
        The dict produced by ``queue_api_calls`` (paragraph number -> image
        bytes, or ``None`` for a failed paragraph).

    Raises:
        RuntimeError: if called from a thread whose event loop is already
            running (``asyncio.run`` cannot be nested).
    """
    print("Processing prompt...")
    print(f"Sentence Mapping: {sentence_mapping}")
    print(f"Character Dict: {character_dict}")
    print(f"Selected Style: {selected_style}")
    cmpt_return = asyncio.run(queue_api_calls(sentence_mapping, character_dict, selected_style))
    # Log keys only; the values are raw image bytes.
    print("Prompt processing complete. Generated images for paragraphs: ", list(cmpt_return))
    # NOTE(review): the values are raw bytes, while the interface output is "json";
    # confirm Gradio can serialize them (base64-encoding may be required).
    return cmpt_return
# Web UI: two JSON inputs plus a style dropdown, JSON output, all handled by process_prompt.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(
            choices=["oil painting", "sketch", "watercolor"],
            label="Selected Style",
        ),
    ],
    outputs="json",
)
# Enable request queuing with up to 20 concurrent runs per event (set concurrency limit if needed);
# keep whatever object queue() hands back, exactly as a chained call would.
gradio_interface = gradio_interface.queue(default_concurrency_limit=20)

if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()
|