File size: 3,361 Bytes
5e2c7ed
690f094
9da79fd
5e2c7ed
 
 
b85438c
0cdadc9
6dc4bbd
13298a2
8551936
f466dd9
5e2c7ed
8551936
690f094
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f466dd9
c9b9787
789e6b5
d26a101
5e2c7ed
690f094
 
bdf16c0
690f094
 
 
5abcbb9
8551936
690f094
 
bdf16c0
 
8551936
 
bdf16c0
5e2c7ed
bdf16c0
690f094
081cd9c
5e2c7ed
 
bdf16c0
5e2c7ed
 
 
 
 
 
bdf16c0
5e2c7ed
 
081cd9c
690f094
 
bdf16c0
690f094
 
 
 
 
bdf16c0
 
 
f466dd9
bdf16c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import asyncio
from generate_prompts import generate_prompt
from diffusers import AutoPipelineForText2Image
from io import BytesIO
import gradio as gr

# Checkpoint identifier handed to ``from_pretrained`` below.
_MODEL_ID = "stabilityai/sdxl-turbo"

# Build the text-to-image pipeline a single time at import; generate_image()
# reuses this module-level object for every request instead of reloading it.
model = AutoPipelineForText2Image.from_pretrained(_MODEL_ID)

def generate_image(prompt, prompt_name):
    """Render *prompt* with the module-level pipeline and return JPEG bytes.

    This is a best-effort worker: every failure is printed and turned into a
    ``None`` result, so one bad prompt does not abort a whole batch.

    Args:
        prompt: Text prompt passed to the pipeline.
        prompt_name: Human-readable label used only in log messages.

    Returns:
        The first generated image encoded as JPEG ``bytes``, or ``None`` if the
        pipeline call failed, returned no images, or the image could not be
        encoded.
    """
    # NOTE(review): queue_api_calls dispatches this function concurrently from
    # executor threads against the single shared ``model``; whether the
    # pipeline tolerates concurrent calls is not established here — confirm.
    print(f"Generating response for {prompt_name}")
    try:
        # One inference step with guidance disabled: the sampling setup this
        # file uses for the sdxl-turbo checkpoint.
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        images = output.images
    except Exception as e:
        print(f"Error generating image for {prompt_name}: {e}")
        return None

    # Guard clause instead of raising an exception only to catch it locally.
    if not isinstance(images, list) or not images:
        print(
            f"Error generating image for {prompt_name}: "
            f"No images returned by the model for {prompt_name}."
        )
        return None

    buffered = BytesIO()
    try:
        images[0].save(buffered, format="JPEG")
    except Exception as e:
        print(f"Error saving image for {prompt_name}: {e}")
        return None

    image_bytes = buffered.getvalue()
    print(f"Image bytes length for {prompt_name}: {len(image_bytes)}")
    return image_bytes

async def queue_api_calls(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph, then render all prompts concurrently.

    Args:
        sentence_mapping: Mapping of paragraph number -> that paragraph's
            sentences (an iterable of strings, or a single string). ``None`` or
            an empty mapping yields an empty result.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Passed through unchanged to ``generate_prompt``.

    Returns:
        Dict mapping each paragraph number to whatever ``generate_image``
        returned for it (JPEG bytes, or ``None`` on failure).
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    if not sentence_mapping:
        # Nothing to render; also avoids calling .items() on a None input.
        return {}

    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        # A bare string would otherwise be joined character by character.
        if isinstance(sentences, str):
            combined_sentence = sentences
        else:
            combined_sentence = " ".join(sentences)
        print(f'combined_sentence: {combined_sentence}, character_dict: {character_dict}, selected_style: {selected_style}')
        prompt = generate_prompt(combined_sentence, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # Offload each blocking generate_image call to the loop's default thread
    # pool so they overlap, then wait for all of them (order is preserved).
    loop = asyncio.get_running_loop()
    tasks = [
        loop.run_in_executor(None, generate_image, prompt, f"Prompt {paragraph_number}")
        for paragraph_number, prompt in prompts
    ]
    responses = await asyncio.gather(*tasks)

    return {
        paragraph_number: response
        for (paragraph_number, _), response in zip(prompts, responses)
    }

def process_prompt(sentence_mapping, character_dict, selected_style):
    """Synchronous entry point for the Gradio interface.

    Runs ``queue_api_calls`` to completion and converts each image's raw bytes
    into a base64 string so the result can be emitted as JSON.

    Args:
        sentence_mapping: Paragraph number -> sentences, forwarded as-is.
        character_dict: Forwarded as-is to ``queue_api_calls``.
        selected_style: Forwarded as-is to ``queue_api_calls``.

    Returns:
        Dict mapping each paragraph number to the base64-encoded (ASCII) JPEG
        for that paragraph, or ``None`` where generation failed.
    """
    import base64
    import concurrent.futures

    def _run():
        # asyncio.run creates a fresh loop and closes it afterwards, so no
        # event loop is leaked per request.
        return asyncio.run(queue_api_calls(sentence_mapping, character_dict, selected_style))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: drive the coroutine directly.
        images = _run()
    else:
        # A loop is already running in this thread, and run_until_complete()
        # on it would raise. Run on a fresh loop in a helper thread instead
        # (this blocks the caller until all images are done).
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            images = executor.submit(_run).result()

    # Raw bytes are not valid JSON; base64 keeps the image payload lossless.
    return {
        paragraph_number: (
            None if image_bytes is None
            else base64.b64encode(image_bytes).decode("ascii")
        )
        for paragraph_number, image_bytes in images.items()
    }

# Gradio interface wrapping process_prompt: two JSON inputs plus a style
# dropdown, with the per-paragraph result rendered as JSON.
# NOTE(review): no concurrency limit is configured here, so Gradio's default
# applies; set one explicitly if higher request concurrency is intended.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    gradio_interface.launch()