text2image_1 / app.py
RanM's picture
Update app.py
efca12e verified
raw
history blame
3.36 kB
import os
import asyncio
from concurrent.futures import ProcessPoolExecutor
from io import BytesIO
from diffusers import StableDiffusionPipeline
import gradio as gr
from generate_prompts import generate_prompt
# Load the model once at the start
# Module-level singleton: the pipeline is loaded a single time at import so
# every request reuses the same weights instead of reloading per call.
# NOTE(review): generate_image() runs inside a ProcessPoolExecutor; the workers
# see this `model` only via process inheritance (fork) — presumably this Space
# runs on Linux where that holds. Verify on spawn-based platforms.
print("Loading the Stable Diffusion model...")
model = StableDiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")
print("Model loaded successfully.")
def generate_image(prompt, prompt_name):
    """Render one prompt with the module-level pipeline and return JPEG bytes.

    Returns None when the pipeline produces no image or raises; failures are
    logged instead of propagated so one bad prompt cannot abort a whole batch.
    """
    try:
        print(f"Generating image for {prompt_name} with prompt: {prompt}")
        result = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output for {prompt_name}: {result}")
        # Guard clause: bail out unless the output actually carries images.
        produced = getattr(result, 'images', None) if result else None
        if not produced:
            print(f"No images found or generated output is None for {prompt_name}")
            return None
        print(f"Image generated for {prompt_name}")
        buffer = BytesIO()
        produced[0].save(buffer, format="JPEG")
        return buffer.getvalue()
    except Exception as e:
        # Best-effort boundary: log and signal failure with None.
        print(f"An error occurred while generating image for {prompt_name}: {e}")
        return None
async def queue_api_calls(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph, then generate all images in parallel.

    Returns a dict mapping paragraph_number -> JPEG bytes (or None where
    generation failed), in the order produced by generate_image.
    """
    print("Starting to queue API calls...")
    numbered_prompts = []
    for para_num, sentence_list in sentence_mapping.items():
        joined_text = " ".join(sentence_list)
        para_prompt = generate_prompt(joined_text, sentence_mapping, character_dict, selected_style)
        numbered_prompts.append((para_num, para_prompt))
        print(f"Generated prompt for paragraph {para_num}: {para_prompt}")
    loop = asyncio.get_running_loop()
    # Fan the CPU-heavy generation out across worker processes.
    with ProcessPoolExecutor() as pool:
        futures = [
            loop.run_in_executor(pool, generate_image, para_prompt, f"Prompt {para_num}")
            for para_num, para_prompt in numbered_prompts
        ]
        results = await asyncio.gather(*futures)
    images = dict(zip((num for num, _ in numbered_prompts), results))
    print("Finished queuing API calls. Generated images: ", images)
    return images
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Synchronous Gradio entry point that drives the async queue_api_calls.

    Bug fix: the original grabbed the *running* loop via
    asyncio.get_running_loop() and then called loop.run_until_complete() on
    that same loop — which unconditionally raises RuntimeError ("This event
    loop is already running") whenever a loop actually is running. Now:
    no running loop -> asyncio.run() drives the coroutine here; a running
    loop -> the coroutine runs on a fresh loop in a helper thread instead.

    Returns the dict produced by queue_api_calls
    (paragraph_number -> image bytes or None).
    """
    print("Processing prompt...")
    print(f"Sentence Mapping: {sentence_mapping}")
    print(f"Character Dict: {character_dict}")
    print(f"Selected Style: {selected_style}")
    coro = queue_api_calls(sentence_mapping, character_dict, selected_style)
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Normal Gradio path: handler thread has no loop of its own.
        print("No running event loop; using asyncio.run().")
        cmpt_return = asyncio.run(coro)
    else:
        # A loop is already running in this thread; run_until_complete here
        # would raise, so execute the coroutine on its own loop in a worker.
        print("Event loop already running; executing on a worker thread.")
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=1) as pool:
            cmpt_return = pool.submit(asyncio.run, coro).result()
    print("Prompt processing complete. Generated images: ", cmpt_return)
    return cmpt_return
# App wiring: two JSON inputs (paragraph->sentences mapping, character
# descriptions) plus a style dropdown; output is the JSON-serialized dict
# returned by process_prompt.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json"
).queue(default_concurrency_limit=20)  # Set concurrency limit if needed

# Launch only when executed directly (Spaces runs this as a script).
if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()