# text2image_1 / app.py
# Author: RanM. Last commit: "Update app.py" (bb032a8, verified). Size: 2.5 kB.
import asyncio
import base64
import os
from io import BytesIO

from PIL import Image
from diffusers import AutoPipelineForText2Image
import gradio as gr
def _load_text2image_model():
    """Load the sdxl-turbo text-to-image pipeline, or return None on failure.

    Any exception raised while loading is printed and swallowed so that the
    module still imports; the module-level ``model`` is then ``None``.
    """
    print("Loading the Stable Diffusion model...")
    try:
        pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    except Exception as exc:
        print(f"Error loading model: {exc}")
        return None
    print("Model loaded successfully.")
    return pipeline


# Shared pipeline used by generate_image; None when loading failed.
model = _load_text2image_model()
def generate_image(
    prompt,
    prompt_name,
    *,
    num_inference_steps=50,
    guidance_scale=7.5,
    pipeline=None,
):
    """Render ``prompt`` with the text-to-image pipeline and return PNG bytes.

    Args:
        prompt: Text prompt passed to the pipeline.
        prompt_name: Human-readable label used only in log messages.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        pipeline: Optional callable pipeline to use instead of the module-level
            ``model`` (e.g. for testing or swapping models).

    Returns:
        The first generated image encoded as PNG ``bytes``, or ``None`` when the
        model is unavailable, returns nothing, yields no images, or raises.
        Errors are printed and swallowed; callers receive ``None`` instead.

    NOTE(review): the stabilityai/sdxl-turbo model card recommends 1-4 steps
    with guidance_scale=0.0; the defaults here keep the original 50 / 7.5
    settings for backward compatibility -- consider changing them.
    """
    try:
        # Only look up the global when no pipeline was injected.
        pipe = model if pipeline is None else pipeline
        if pipe is None:
            raise ValueError("Model not loaded properly.")
        print(f"Generating image for {prompt_name} with prompt: {prompt}")
        output = pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        print(f"Model output for {prompt_name}: {output}")
        if output is None:
            raise ValueError(f"Model returned None for {prompt_name}")
        images = getattr(output, "images", None)
        if not images:
            print(f"No images found in model output for {prompt_name}")
            raise ValueError(f"No images found in model output for {prompt_name}")
        print(f"Image generated for {prompt_name}")
        with BytesIO() as buffered:
            images[0].save(buffered, format="PNG")
            return buffered.getvalue()
    except Exception as e:
        # Best-effort: report and return None so the caller can continue.
        print(f"An error occurred while generating image for {prompt_name}: {e}")
        return None
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Generate one illustration per paragraph and return them as base64 PNGs.

    Args:
        sentence_mapping: Mapping of paragraph id -> that paragraph's sentences
            (a list of strings, or a single string), from the
            "Sentence Mapping" JSON input.
        character_dict: Value of the "Character Dict" JSON input. It is only
            logged here. NOTE(review): presumably meant to enrich the prompt;
            confirm intent.
        selected_style: Style name interpolated into the prompt.

    Returns:
        Dict mapping each paragraph id to a base64-encoded PNG (ASCII ``str``),
        or ``None`` for paragraphs whose generation failed. The result is sent
        through a JSON output, which cannot carry raw ``bytes``, so the images
        are base64-encoded. Returns ``{}`` when ``sentence_mapping`` is empty
        or ``None``.
    """
    print("Processing prompt...")
    print(f"Sentence Mapping: {sentence_mapping}")
    print(f"Character Dict: {character_dict}")
    print(f"Selected Style: {selected_style}")
    prompt_results = {}
    # The JSON input may deliver None when left empty.
    for paragraph_number, sentences in (sentence_mapping or {}).items():
        # A bare string would otherwise be joined character by character.
        if isinstance(sentences, str):
            combined_sentence = sentences
        else:
            combined_sentence = " ".join(sentences)
        prompt = f"Make an illustration in {selected_style} style from: {combined_sentence}"
        image_bytes = generate_image(prompt, f"Prompt {paragraph_number}")
        if image_bytes is None:
            prompt_results[paragraph_number] = None
        else:
            prompt_results[paragraph_number] = base64.b64encode(image_bytes).decode("ascii")
    return prompt_results
# Gradio UI/API: two JSON inputs plus a style dropdown, JSON output.
# concurrency_limit is an Interface/event-level setting in Gradio 4+;
# Blocks.queue() does not accept it and raises TypeError if it is passed there.
# NOTE(review): concurrent requests all share the single module-level
# pipeline; verify that this is safe for the loaded diffusers pipeline, or
# lower the limit to 1.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
    concurrency_limit=10,
).queue()
def main():
    """Launch the Gradio interface defined in this module."""
    print("Launching Gradio interface...")
    gradio_interface.launch()


if __name__ == "__main__":
    main()