File size: 4,868 Bytes
5e2c7ed
690f094
9da79fd
5e2c7ed
 
 
3b7350e
b85438c
0cdadc9
c7f120b
6dc4bbd
c7f120b
13298a2
3b7350e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d26a101
5e2c7ed
c7f120b
690f094
bdf16c0
690f094
 
 
c7f120b
3b7350e
690f094
 
bdf16c0
 
3b7350e
 
c7f120b
bdf16c0
c7f120b
 
bdf16c0
c7f120b
690f094
081cd9c
5e2c7ed
c7f120b
5e2c7ed
bdf16c0
5e2c7ed
 
 
 
 
c7f120b
5e2c7ed
3b7350e
 
 
bdf16c0
5e2c7ed
c7f120b
5e2c7ed
081cd9c
690f094
 
bdf16c0
690f094
 
 
 
 
bdf16c0
 
 
f466dd9
c7f120b
bdf16c0
c7f120b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import asyncio
from generate_prompts import generate_prompt
from diffusers import AutoPipelineForText2Image
from io import BytesIO
import gradio as gr
import threading

# Load the SDXL-Turbo text-to-image pipeline once at import time so all
# requests share the same weights instead of reloading them per call.
print("Loading the model...")
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
print("Model loaded successfully.")

# Create a thread-local storage object.
# NOTE(review): nothing in this file reads or writes `thread_local`
# (Scheduler keeps its own threading.local) — presumably leftover; confirm
# before removing.
thread_local = threading.local()

class Scheduler:
    """Per-thread inference step counter.

    Each thread sees an independent ``step`` value backed by
    ``threading.local``. Worker threads initialize their counter lazily by
    checking ``step is None`` and then calling ``_init_step_index``.

    Bug fixed: the original ``__init__`` set ``self._step.step = None``,
    which initializes the attribute only for the thread that constructed the
    object — any other thread reading ``step`` raised ``AttributeError``
    instead of returning ``None``, breaking the lazy-init guard in
    ``generate_image`` when it runs in executor threads. The ``step``
    property now uses ``getattr`` with a ``None`` default so the guard works
    from every thread.
    """

    def __init__(self):
        # Per-thread storage; each thread's ``step`` attribute is created
        # lazily via _init_step_index().
        self._step = threading.local()

    def _init_step_index(self):
        """Reset the calling thread's step counter to 0."""
        self._step.step = 0

    @property
    def step(self):
        """Return the calling thread's current step, or None if this thread
        has never called ``_init_step_index``."""
        return getattr(self._step, "step", None)

    def step_process(self):
        """Advance the calling thread's step counter by one.

        The calling thread must have run ``_init_step_index`` first
        (``generate_image`` guarantees this via its ``step is None`` guard).
        """
        self._step.step += 1

# Shared module-level instance used by generate_image / process_prompt.
scheduler = Scheduler()

def generate_image(prompt, prompt_name):
    """Run the SDXL-Turbo pipeline for one prompt and return JPEG bytes.

    Args:
        prompt: Text prompt passed straight to the diffusion pipeline.
        prompt_name: Label used only for log messages.

    Returns:
        The encoded JPEG as ``bytes``, or ``None`` on any failure (errors
        are logged, never raised to the caller).
    """
    try:
        # Lazily initialize this thread's step counter on first use.
        if scheduler.step is None:
            scheduler._init_step_index()

        print(f"Initial step index for {prompt_name}: {scheduler.step}")
        print(f"Generating response for {prompt_name} with prompt: {prompt}")

        # Single-step, guidance-free inference (SDXL-Turbo convention).
        result = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)

        # Bump and log this thread's step counter.
        scheduler.step_process()
        print(f"Updated step index for {prompt_name}: {scheduler.step}")

        print(f"Output for {prompt_name}: {result}")

        # The pipeline is expected to hand back a non-empty list of images.
        if not (isinstance(result.images, list) and len(result.images) > 0):
            raise Exception(f"No images returned by the model for {prompt_name}.")

        pil_image = result.images[0]
        jpeg_buffer = BytesIO()
        try:
            pil_image.save(jpeg_buffer, format="JPEG")
            payload = jpeg_buffer.getvalue()
            print(f"Image bytes length for {prompt_name}: {len(payload)}")
            return payload
        except Exception as e:
            print(f"Error saving image for {prompt_name}: {e}")
            return None
    except Exception as e:
        # Catch-all boundary: any failure is logged and reported as None.
        print(f"Error generating image for {prompt_name}: {e}")
        return None

async def queue_api_calls(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph, then render all images concurrently.

    Args:
        sentence_mapping: Mapping of paragraph number -> list of sentences.
        character_dict: Character metadata forwarded to generate_prompt.
        selected_style: Art-style name forwarded to generate_prompt.

    Returns:
        Dict mapping each paragraph number to its JPEG bytes (or None when
        generation failed for that paragraph).
    """
    print(f"queue_api_calls invoked with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")

    # Phase 1: one prompt per paragraph, in mapping order.
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        print(f"combined_sentence for paragraph {paragraph_number}: {combined_sentence}")
        paragraph_prompt = generate_prompt(combined_sentence, character_dict, selected_style)
        prompts.append((paragraph_number, paragraph_prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {paragraph_prompt}")

    # Phase 2: fan the blocking image calls out to the default executor so
    # they run in parallel threads, then await them all together.
    loop = asyncio.get_running_loop()
    futures = [
        loop.run_in_executor(None, generate_image, text, f"Prompt {number}")
        for number, text in prompts
    ]
    print("Tasks created for image generation.")
    responses = await asyncio.gather(*futures)
    print("Responses received from image generation tasks.")

    # gather preserves order, so results line up positionally with prompts.
    images = {}
    for (paragraph_number, _), response in zip(prompts, responses):
        images[paragraph_number] = response
    print(f"Images generated: {images}")
    return images

def process_prompt(sentence_mapping, character_dict, selected_style):
    """Synchronous Gradio entry point: run queue_api_calls to completion.

    Args:
        sentence_mapping: Mapping of paragraph number -> list of sentences.
        character_dict: Character metadata forwarded to prompt generation.
        selected_style: Art-style name forwarded to prompt generation.

    Returns:
        The dict produced by queue_api_calls (paragraph number -> image bytes).

    Bugs fixed versus the original:
      * The "reuse the running loop" branch could never work —
        ``loop.run_until_complete`` on an already-running loop raises
        ``RuntimeError: This event loop is already running``.
      * When no loop existed, a fresh event loop was created on every call
        and never closed, leaking loops. ``asyncio.run`` now creates and
        cleans one up per call.
    """
    print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")

    # Initialize this thread's step counter; executor worker threads
    # lazily initialize their own inside generate_image.
    scheduler._init_step_index()

    coro = queue_api_calls(sentence_mapping, character_dict, selected_style)
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Normal path: no loop running in this thread. asyncio.run creates
        # a loop, runs the coroutine, and closes the loop afterwards.
        cmpt_return = asyncio.run(coro)
    else:
        # A loop is already running in this thread, so run_until_complete
        # would raise. Execute the coroutine on its own loop in a helper
        # thread and wait for it.
        box = {}

        def _runner():
            box["result"] = asyncio.run(coro)

        worker = threading.Thread(target=_runner)
        worker.start()
        worker.join()
        cmpt_return = box["result"]

    print(f"process_prompt completed with return value: {cmpt_return}")
    return cmpt_return

# Gradio interface with high concurrency limit.
# NOTE(review): no concurrency arguments are actually passed here — if a high
# concurrency limit is intended, it must be configured via launch()/queue();
# confirm against the deployment setup.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        # JSON payloads: paragraph number -> sentences, and character metadata.
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json"
)

# Script entry point: start the Gradio web server.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()  # blocks until the server is shut down
    print("Gradio interface launched.")