Spaces:
Runtime error
Runtime error
File size: 3,721 Bytes
c513221 86743ba 9da79fd 5e2c7ed 86743ba b85438c 86743ba 3b7350e 86743ba 3b7350e 86743ba 3b7350e 86743ba d26a101 86743ba 690f094 c7f120b e284958 690f094 bdf16c0 86743ba c7f120b 86743ba c7f120b 690f094 081cd9c 5e2c7ed a04441d 86743ba 081cd9c 690f094 bdf16c0 690f094 bdf16c0 f466dd9 c7f120b bdf16c0 c7f120b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import os
import multiprocessing
from generate_prompts import generate_prompt
from diffusers import AutoPipelineForText2Image
from io import BytesIO
import gradio as gr
import json
# Process-local handle to the text-to-image pipeline; stays None until initialize_model() runs.
model = None


def initialize_model():
    """Load the ``stabilityai/sdxl-turbo`` pipeline into the module-level ``model``.

    Serves as the ``multiprocessing.Pool`` initializer, so every worker process
    loads the pipeline a single time; calling it again in a process that already
    holds a model does nothing.
    """
    global model
    if model is not None:
        # Already loaded in this process -- nothing to do.
        return
    print("Loading the model...")
    model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    print("Model loaded successfully.")
def generate_image(prompt, prompt_name):
    """Render one image for ``prompt`` with the process-local ``model`` and JPEG-encode it.

    Args:
        prompt: Text prompt passed to the text-to-image pipeline.
        prompt_name: Label used in log messages and echoed back in the result.

    Returns:
        A ``(prompt_name, image_bytes)`` tuple. ``image_bytes`` is the JPEG encoding
        of the first generated image. It is ``None`` if the model is not loaded,
        the pipeline raised, it returned no images, or encoding failed. Failures are
        logged instead of raised, so one bad prompt does not make ``Pool.starmap``
        abort the whole batch.
    """
    if model is None:
        # initialize_model() has not run in this process (e.g. called outside the pool).
        print(f"Error generating image for {prompt_name}: model is not loaded")
        return prompt_name, None

    print(f"Generating response for {prompt_name} with prompt: {prompt}")
    try:
        # One denoising step with guidance disabled (guidance_scale=0.0).
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
    except Exception as e:  # worker boundary: report failure, don't crash the pool
        print(f"Error generating image for {prompt_name}: {e}")
        return prompt_name, None

    images = getattr(output, "images", None)
    if not isinstance(images, list) or not images:
        print(f"Error generating image for {prompt_name}: No images returned by the model for {prompt_name}.")
        return prompt_name, None

    buffered = BytesIO()
    try:
        images[0].save(buffered, format="JPEG")
    except Exception as e:  # e.g. an image mode JPEG cannot store
        print(f"Error saving image for {prompt_name}: {e}")
        return prompt_name, None

    image_bytes = buffered.getvalue()
    print(f"Image bytes length for {prompt_name}: {len(image_bytes)}")
    return prompt_name, image_bytes
def process_prompts(sentence_mapping, character_dict, selected_style, max_workers=None):
    """Build one prompt per paragraph and render them in parallel worker processes.

    Args:
        sentence_mapping: Mapping of paragraph number to that paragraph's sentences
            (strings; they are joined with spaces into one text).
        character_dict: Forwarded unchanged to ``generate_prompt``.
        selected_style: Forwarded unchanged to ``generate_prompt``.
        max_workers: Upper bound on worker processes.
            - ``None`` (the default) means ``os.cpu_count()``.
            - Each worker loads its own copy of the model through ``initialize_model``,
              so lower this on memory-constrained hosts.
            - Never more workers than prompts are started, and always at least one.

    Returns:
        Dict mapping ``"Prompt <paragraph_number>"`` to the JPEG bytes for that
        paragraph, or ``None`` where generation failed. Empty when
        ``sentence_mapping`` is empty.
    """
    print(f"process_prompts called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        print(f"combined_sentence for paragraph {paragraph_number}: {combined_sentence}")
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    num_prompts = len(prompts)
    print(f"Number of prompts: {num_prompts}")
    if not prompts:
        # multiprocessing.Pool rejects processes=0, so there is nothing to start.
        return {}

    if max_workers is None:
        max_workers = os.cpu_count() or 1
    # One worker per prompt at most, bounded by max_workers; Pool needs at least 1.
    num_workers = max(1, min(num_prompts, max_workers))
    print(f"Using {num_workers} worker process(es)")

    tasks = [(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts]
    with multiprocessing.Pool(processes=num_workers, initializer=initialize_model) as pool:
        results = pool.starmap(generate_image, tasks)

    images = dict(results)
    # Log encoded sizes rather than dumping raw image bytes into the console.
    sizes = {name: None if data is None else len(data) for name, data in images.items()}
    print(f"Images generated (JPEG byte sizes): {sizes}")
    return images
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Gradio handler: decode any JSON-text inputs, then delegate to ``process_prompts``.

    ``sentence_mapping`` and ``character_dict`` may arrive either as already-parsed
    objects (presumably what ``gr.JSON`` supplies) or as raw JSON strings. Strings
    are parsed with ``json.loads``; any other value is forwarded untouched.
    """
    print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")

    def _decode(value):
        # Only text needs parsing; dicts and other objects pass straight through.
        return json.loads(value) if isinstance(value, str) else value

    return process_prompts(_decode(sentence_mapping), _decode(character_dict), selected_style)
def _process_prompt_for_json(sentence_mapping, character_dict, selected_style):
    """Run ``process_prompt`` and make its result serializable for the ``"json"`` output.

    ``process_prompt`` returns a dict whose values are raw JPEG ``bytes`` (or
    ``None`` for failed prompts). ``bytes`` cannot be JSON-encoded, so each image
    is returned as a base64 ASCII string instead; ``None`` entries stay ``None``.
    """
    import base64

    images = process_prompt(sentence_mapping, character_dict, selected_style)
    return {
        name: None if data is None else base64.b64encode(data).decode("ascii")
        for name, data in images.items()
    }


# Web UI: two JSON inputs plus a style picker; results are displayed as JSON.
gradio_interface = gr.Interface(
    fn=_process_prompt_for_json,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
)
if __name__ == "__main__":
    # The __main__ guard also keeps worker processes started with the "spawn" method
    # from relaunching the server when they re-import this module.
    print("Launching Gradio interface...")
    # launch() blocks the main thread while the server runs, so the next line only
    # executes after the server has shut down.
    gradio_interface.launch()
    print("Gradio interface stopped.")
|