File size: 4,261 Bytes
c513221
109adde
441106f
9da79fd
5e2c7ed
 
 
5d9bf5a
97389f0
b85438c
5d9bf5a
86743ba
5d9bf5a
 
 
 
 
 
 
 
441106f
 
5d9bf5a
441106f
5d9bf5a
441106f
5d9bf5a
 
 
 
 
 
 
441106f
 
 
5d9bf5a
 
441106f
5d9bf5a
 
441106f
5d9bf5a
441106f
5d9bf5a
3b7350e
97389f0
 
d26a101
97389f0
109adde
690f094
109adde
690f094
 
c7f120b
fb8ec57
690f094
 
bdf16c0
fb8ec57
5d9bf5a
30c04e8
441106f
5d9bf5a
c7f120b
cfeca25
c7f120b
690f094
081cd9c
5e2c7ed
a04441d
109adde
 
 
 
 
 
 
97389f0
 
109adde
 
081cd9c
690f094
bdf16c0
690f094
 
 
 
 
bdf16c0
 
 
f466dd9
c7f120b
c14304d
c7f120b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import asyncio
import time
from generate_prompts import generate_prompt
from diffusers import AutoPipelineForText2Image
from io import BytesIO
import gradio as gr
import ray
from ray.util import ActorPool

# Initialise Ray at import time so the @ray.remote ModelActor below can be
# instantiated (create_actor_pool calls ModelActor.remote()).
ray.init()

@ray.remote
class ModelActor:
    """Ray actor that owns one ``stabilityai/sdxl-turbo`` text-to-image pipeline.

    The pipeline is loaded once in ``__init__``; every actor instance therefore
    holds its own copy of the model.
    """

    def __init__(self):
        print("Loading the model...")
        self.model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
        print("Model loaded successfully.")

    def generate_image(self, prompt, prompt_name):
        """Render ``prompt`` and return the first generated image as JPEG bytes.

        Args:
            prompt: Text prompt passed to the pipeline.
            prompt_name: Label used only in log messages.

        Returns:
            JPEG-encoded bytes of the first image, or ``None`` if generation
            failed, the pipeline returned no images, or JPEG encoding failed.
            Errors are logged rather than raised, so the caller receives
            ``None`` instead of a task exception.
        """
        start_time = time.perf_counter()
        process_id = os.getpid()
        print(f"[{process_id}] Generating response for {prompt_name} with prompt: {prompt}")

        try:
            # Single inference step with guidance disabled (guidance_scale=0.0),
            # the configuration this file uses for sdxl-turbo.
            output = self.model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        except Exception as e:  # best-effort: one failed prompt must not crash the actor call
            print(f"[{process_id}] Error generating image for {prompt_name}: {e}")
            return None
        print(f"[{process_id}] Output for {prompt_name}: {output}")

        images = output.images
        if not isinstance(images, list) or not images:
            print(f"[{process_id}] Error generating image for {prompt_name}: no images returned by the model.")
            return None

        buffered = BytesIO()
        try:
            images[0].save(buffered, format="JPEG")
        except Exception as e:  # e.g. an image mode JPEG cannot encode
            print(f"[{process_id}] Error saving image for {prompt_name}: {e}")
            return None

        image_bytes = buffered.getvalue()
        elapsed = time.perf_counter() - start_time
        print(f"[{process_id}] Image bytes length for {prompt_name}: {len(image_bytes)}")
        print(f"[{process_id}] Time taken for {prompt_name}: {elapsed} seconds")
        return image_bytes

def create_actor_pool(num_actors):
    """Create ``num_actors`` ModelActor instances wrapped in a Ray ``ActorPool``.

    Each actor loads its own copy of the model in its constructor, so this is
    expensive.

    Raises:
        ValueError: if ``num_actors`` is less than 1. An empty pool would
            accept submissions but never produce results.
    """
    if num_actors < 1:
        raise ValueError(f"num_actors must be >= 1, got {num_actors}")
    return ActorPool([ModelActor.remote() for _ in range(num_actors)])

async def queue_api_calls(sentence_mapping, character_dict, selected_style, pool):
    """Build one prompt per paragraph and render them on the actors in ``pool``.

    Args:
        sentence_mapping: Mapping of paragraph number -> iterable of sentences;
            the sentences of each paragraph are joined with spaces.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Passed through unchanged to ``generate_prompt``.
        pool: ``ray.util.ActorPool`` whose actors expose ``generate_image``.

    Returns:
        Dict mapping each paragraph number to the value returned by
        ``generate_image`` for it (JPEG bytes, or ``None`` on failure).
    """
    print(f"queue_api_calls invoked with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
    prompts = []

    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        print(f"combined_sentence for paragraph {paragraph_number}: {combined_sentence}")
        prompt = generate_prompt(combined_sentence, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    def _submit(actor, item):
        # ActorPool calls fn(actor, value) and expects an ObjectRef back.
        paragraph_number, prompt = item
        return actor.generate_image.remote(prompt, f"Prompt {paragraph_number}")

    print("Tasks created for image generation.")
    # ActorPool.submit() returns None; results must be collected through the
    # pool itself. map() schedules every prompt across the idle actors and
    # yields results in submission order. It blocks, so it runs in a worker
    # thread, and all pool operations stay in that single thread.
    responses = await asyncio.to_thread(lambda: list(pool.map(_submit, prompts)))
    print("Responses received from image generation tasks.")

    images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
    # Log sizes rather than the raw JPEG bytes, which would flood the output.
    sizes = {pn: (None if data is None else len(data)) for pn, data in images.items()}
    print(f"Images generated (bytes per paragraph): {sizes}")
    return images

def process_prompt(sentence_mapping, character_dict, selected_style):
    """Synchronous entry point: render one image per paragraph of ``sentence_mapping``.

    Args:
        sentence_mapping: Mapping of paragraph number -> sentences. An empty
            or ``None`` mapping yields ``{}`` without starting any actors.
        character_dict: Forwarded to prompt generation.
        selected_style: Forwarded to prompt generation.

    Returns:
        Dict mapping paragraph number -> JPEG bytes (or ``None`` on failure).
    """
    print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
    if not sentence_mapping:
        # Nothing to render; avoid spawning an actor that loads the model for no work.
        print("process_prompt received no paragraphs; returning an empty result.")
        return {}

    # One actor per paragraph, capped at 20 (each actor holds its own model copy).
    pool = create_actor_pool(min(20, len(sentence_mapping)))
    # asyncio.run() creates a fresh event loop and closes it when done.
    # run_until_complete() cannot be used on a loop that is already running,
    # and creating a loop per call without closing it leaks loops.
    cmpt_return = asyncio.run(queue_api_calls(sentence_mapping, character_dict, selected_style, pool))
    print(f"process_prompt completed with return value: {cmpt_return}")
    return cmpt_return

def _process_prompt_for_ui(sentence_mapping, character_dict, selected_style):
    """Run ``process_prompt`` and make its result JSON-serialisable for the UI.

    ``process_prompt`` returns raw JPEG ``bytes``, which JSON cannot represent.
    Each image is therefore base64-encoded into an ASCII string; failed
    paragraphs (``None``) are passed through unchanged.
    """
    import base64

    images = process_prompt(sentence_mapping, character_dict, selected_style)
    return {
        paragraph: None if data is None else base64.b64encode(data).decode("ascii")
        for paragraph, data in images.items()
    }


gradio_interface = gr.Interface(
    fn=_process_prompt_for_ui,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
)

if __name__ == "__main__":
    print("Launching Gradio interface...")
    # With default arguments launch() blocks the main thread while the server
    # runs (see Gradio's prevent_thread_lock), so the line below executes only
    # after the server has been stopped.
    gradio_interface.launch()
    print("Gradio interface stopped.")