Spaces:
Runtime error
Runtime error
File size: 3,849 Bytes
c513221 109adde 441106f 9da79fd 5e2c7ed 5d9bf5a b85438c 5d9bf5a 86743ba 5d9bf5a 28413d5 5d9bf5a 630a72e 28413d5 441106f 5d9bf5a 630a72e 5d9bf5a 28413d5 5d9bf5a 28413d5 5d9bf5a 3b7350e 834f7ba 28413d5 690f094 d253f4a 690f094 bdf16c0 834f7ba 28413d5 834f7ba 30c04e8 575d097 cfeca25 690f094 081cd9c 5e2c7ed 28413d5 109adde 834f7ba 109adde 081cd9c 690f094 bdf16c0 28413d5 bdf16c0 f466dd9 630a72e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import asyncio
import concurrent.futures
import logging
import os
import time
from io import BytesIO

import gradio as gr
import ray
from diffusers import AutoPipelineForText2Image

from generate_prompts import generate_prompt
# Initialize Ray at import time, before any ModelActor below is instantiated.
ray.init()
@ray.remote
class ModelActor:
    """
    Ray actor that owns one text-to-image pipeline.

    Every actor instance loads its own copy of the pipeline in ``__init__``,
    so the number of actors created decides how many copies are loaded.
    """

    def __init__(self):
        """
        Initializes the ModelActor class and loads the text-to-image model.
        """
        self.model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

    async def generate_image(self, prompt, prompt_name):
        """
        Generates an image based on the provided prompt.

        Parameters:
        - prompt (str): The input text for image generation.
        - prompt_name (str): A name for the prompt, used for logging.

        Returns:
        bytes: The generated image encoded as JPEG, or None if generation
        fails or the pipeline returns no images. Failures are logged with
        their traceback instead of being silently discarded.
        """
        log = logging.getLogger(__name__)
        start_time = time.time()
        process_id = os.getpid()
        try:
            # The pipeline call is an ordinary blocking call that returns the
            # output object directly. Awaiting that object raised TypeError,
            # which the handler below swallowed, so every request came back
            # as None. The method stays `async def` so callers are unaffected.
            output = self.model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
            images = output.images
            if not isinstance(images, list) or not images:
                log.warning("%s: pipeline returned no images (pid %d)", prompt_name, process_id)
                return None
            buffered = BytesIO()
            images[0].save(buffered, format="JPEG")
            image_bytes = buffered.getvalue()
        except Exception:
            # Best-effort contract: one failed prompt must not take down the
            # whole batch, so report the error and signal failure with None.
            log.exception("%s: image generation failed (pid %d)", prompt_name, process_id)
            return None
        log.info(
            "%s: generated %d bytes in %.2fs (pid %d)",
            prompt_name,
            len(image_bytes),
            time.time() - start_time,
            process_id,
        )
        return image_bytes
async def queue_api_calls(sentence_mapping, character_dict, selected_style, max_actors=20):
    """
    Generates images for all provided prompts in parallel using Ray actors.

    Parameters:
    - sentence_mapping (dict): Mapping between paragraph numbers and sentences.
    - character_dict (dict): Dictionary mapping characters to their descriptions.
    - selected_style (str): Selected illustration style.
    - max_actors (int): Upper bound on the number of ModelActor instances to
      create. Defaults to 20; values below 1 are treated as 1.

    Returns:
    dict: A dictionary where keys are paragraph numbers and values are image
    data in bytes format (None for a paragraph whose generation failed).
    An empty or missing mapping yields an empty dictionary.
    """
    # Nothing to illustrate: avoid creating actors (and loading models) at all.
    if not sentence_mapping:
        return {}

    prompts = [
        (
            paragraph_number,
            generate_prompt(" ".join(sentences), sentence_mapping, character_dict, selected_style),
        )
        for paragraph_number, sentences in sentence_mapping.items()
    ]

    # Never create more actors than prompts; each actor loads its own model.
    num_actors = max(1, min(len(prompts), max_actors))
    model_actors = [ModelActor.remote() for _ in range(num_actors)]

    # Distribute prompts over the actors round-robin.
    tasks = [
        model_actors[i % num_actors].generate_image.remote(prompt, f"Prompt {paragraph_number}")
        for i, (paragraph_number, prompt) in enumerate(prompts)
    ]

    # Await the Ray object refs directly. Wrapping them in ray.get() blocked
    # the event loop and then passed plain results to asyncio.gather, which
    # rejects anything that is not awaitable.
    responses = await asyncio.gather(*tasks)

    return {
        paragraph_number: response
        for (paragraph_number, _), response in zip(prompts, responses)
    }
def process_prompt(sentence_mapping, character_dict, selected_style):
    """
    Processes the provided prompts and generates images.

    Synchronous entry point that drives the ``queue_api_calls`` coroutine to
    completion and returns its result.

    Parameters:
    - sentence_mapping (dict): Mapping between paragraph numbers and sentences.
    - character_dict (dict): Dictionary mapping characters to their descriptions.
    - selected_style (str): Selected illustration style.

    Returns:
    dict: A dictionary where keys are paragraph numbers and values are image data in bytes format.
    """

    def _run():
        # asyncio.run creates a fresh event loop and closes it afterwards,
        # so no loop is leaked per request.
        return asyncio.run(queue_api_calls(sentence_mapping, character_dict, selected_style))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread: run the coroutine right here.
        return _run()

    # A loop is already running in this thread. Calling run_until_complete on
    # it would raise "This event loop is already running", so run the
    # coroutine on its own loop in a short-lived worker thread and wait.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(_run).result()
# Gradio front end: two JSON inputs plus a style selector, wired to
# process_prompt, with the result shown as JSON.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(
            ["oil painting", "sketch", "watercolor"],
            label="Selected Style",
        ),
    ],
    outputs="json",
)

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    gradio_interface.launch()
|