import os
import asyncio
import base64
import time
from generate_prompts import generate_prompt
from diffusers import AutoPipelineForText2Image
from io import BytesIO
import gradio as gr
import ray

# Start (or connect to) a local Ray instance; a cluster address or resource
# limits could be passed here if needed.
ray.init()

@ray.remote
class ModelActor:
    def __init__(self):
        """
        Initializes the ModelActor class and loads the text-to-image model.
        """
        self.model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

    async def generate_image(self, prompt, prompt_name):
        """
        Generates an image based on the provided prompt.
        Parameters:
            - prompt (str): The input text for image generation.
            - prompt_name (str): A name for the prompt, used for logging.
        Returns:
            bytes: The generated image data in bytes format, or None if generation fails.
        """
        start_time = time.time()
        process_id = os.getpid()
        try:
            # The diffusers pipeline call is synchronous, so it is not awaited;
            # Ray already runs each actor call in its own worker process.
            output = self.model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
            if isinstance(output.images, list) and len(output.images) > 0:
                image = output.images[0]
                buffered = BytesIO()
                image.save(buffered, format="JPEG")
                image_bytes = buffered.getvalue()
                elapsed = time.time() - start_time
                print(f"[pid {process_id}] {prompt_name} generated in {elapsed:.2f}s")
                return image_bytes
            else:
                print(f"[pid {process_id}] {prompt_name} returned no images")
                return None
        except Exception as e:
            print(f"[pid {process_id}] {prompt_name} failed: {e}")
            return None

async def queue_api_calls(sentence_mapping, character_dict, selected_style):
    """
    Generates images for all provided prompts in parallel using Ray actors.
    Parameters:
        - sentence_mapping (dict): Mapping between paragraph numbers and sentences.
        - character_dict (dict): Dictionary mapping characters to their descriptions.
        - selected_style (str): Selected illustration style.
    Returns:
        dict: A dictionary where keys are paragraph numbers and values are image data in bytes format.
    """
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))

    num_prompts = len(prompts)
    # Cap the actor pool at 20; each ModelActor loads its own copy of the
    # pipeline, so this trades memory for parallelism.
    num_actors = min(num_prompts, 20)
    model_actors = [ModelActor.remote() for _ in range(num_actors)]
    tasks = [
        model_actors[i % num_actors].generate_image.remote(prompt, f"Prompt {paragraph_number}")
        for i, (paragraph_number, prompt) in enumerate(prompts)
    ]

    # Ray ObjectRefs are awaitable, so gather on them directly rather than
    # blocking the event loop with ray.get().
    responses = await asyncio.gather(*tasks)
    images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
    return images

def process_prompt(sentence_mapping, character_dict, selected_style):
    """
    Processes the provided prompts and generates images.
    Parameters:
        - sentence_mapping (dict): Mapping between paragraph numbers and sentences.
        - character_dict (dict): Dictionary mapping characters to their descriptions.
        - selected_style (str): Selected illustration style.
    Returns:
        dict: A dictionary where keys are paragraph numbers and values are base64-encoded JPEG strings (raw bytes are not JSON-serializable).
    """
    # Gradio calls this function from a worker thread, so there is usually no
    # running event loop; create and register one if needed.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    images = loop.run_until_complete(queue_api_calls(sentence_mapping, character_dict, selected_style))
    # Base64-encode the image bytes so they survive the JSON output component.
    return {
        paragraph_number: base64.b64encode(image_bytes).decode("utf-8") if image_bytes else None
        for paragraph_number, image_bytes in images.items()
    }

gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[gr.JSON(label="Sentence Mapping"), gr.JSON(label="Character Dict"), gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")],
    outputs="json"
)
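
# Hypothetical usage sketch, outside the Gradio flow. The exact input shapes that
# generate_prompt expects are an assumption; the values below are illustrative only.
#
#   sentence_mapping = {"1": ["A fox trots across the snow.", "It stops by a frozen lake."]}
#   character_dict = {"Fox": "a small red fox with a bushy tail"}
#   result = process_prompt(sentence_mapping, character_dict, "watercolor")
#   # result -> {"1": "<base64-encoded JPEG>"}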

if __name__ == "__main__":
    gradio_interface.launch()