File size: 3,130 Bytes
081cd9c
bc0d978
0cdadc9
cd42042
b85438c
0cdadc9
6dc4bbd
13298a2
081cd9c
f466dd9
c9b9787
081cd9c
6dc4bbd
bc0d978
 
 
f466dd9
c9b9787
789e6b5
d26a101
081cd9c
 
 
 
 
 
 
 
 
 
 
 
 
 
bc0d978
 
081cd9c
bc0d978
 
 
 
081cd9c
 
 
 
 
bc0d978
081cd9c
 
 
 
e4c2663
081cd9c
 
 
 
 
 
 
 
 
 
 
 
 
 
f466dd9
 
bc0d978
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import asyncio
import concurrent.futures
import io
import tempfile

import gradio as gr
from diffusers import AutoPipelineForText2Image

from generate_prompts import generate_prompt

# Build the SDXL-Turbo text-to-image pipeline a single time at import so that
# every request below reuses the same pipeline object.
MODEL_ID = "stabilityai/sdxl-turbo"
model = AutoPipelineForText2Image.from_pretrained(MODEL_ID)

async def generate_image(prompt, prompt_name):
    """Render one image for *prompt* with the shared module-level pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        prompt_name: Label used only in the log messages.

    Returns:
        The first generated image encoded as PNG ``bytes``, or ``None`` if
        generation failed (the error is printed, not raised).
    """
    print(f"Generating image for {prompt_name}")
    try:
        # The pipeline call is synchronous and returns an output object, not
        # an awaitable. Awaiting it raised TypeError, which the handler below
        # swallowed, so every call used to return None.
        # NOTE: calls run inline (sequentially) on purpose. All prompts share
        # one pipeline object, and offloading them with asyncio.to_thread
        # would overlap calls on it.
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        image = output.images[0]
        # Encode as PNG. Image.tobytes() yields bare pixel data with no
        # size/mode header, which cannot be turned back into an image.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
    except Exception as e:
        # Best-effort per prompt: one failure must not abort the whole batch.
        print(f"Error generating image for {prompt_name}: {e}")
        return None
    img_bytes = buffer.getvalue()
    print(f"Image bytes length for {prompt_name}: {len(img_bytes)}")
    return img_bytes

async def queue_image_calls(prompts):
    """Schedule one ``generate_image`` coroutine per prompt and await them all.

    Args:
        prompts: Iterable of prompt strings. Each is labelled
            ``"Prompt <index>"`` (0-based) in the log output.

    Returns:
        List of ``generate_image`` results in the same order as *prompts*
        (``asyncio.gather`` preserves input order).
    """
    tasks = [
        generate_image(prompt, f"Prompt {index}")
        for index, prompt in enumerate(prompts)
    ]
    return await asyncio.gather(*tasks)

def async_image_generation(prompts):
    """Generate images for every prompt, blocking until all are finished.

    Callable from synchronous code, whether or not the calling thread
    already has an asyncio event loop running.

    Args:
        prompts: Iterable of prompt strings.

    Returns:
        List with one ``queue_image_calls`` result per prompt, in order.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread. asyncio.run creates a fresh loop and closes
        # it afterwards; the old code created a new loop per call and never
        # closed it.
        return asyncio.run(queue_image_calls(prompts))
    # A loop is already running in this thread, and run_until_complete() on
    # a running loop raises RuntimeError. Run the batch on a private loop in
    # a worker thread instead. Note this still blocks the current thread
    # until the batch is done.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(asyncio.run, queue_image_calls(prompts))
        return future.result()

def gradio_interface(sentence_mapping, character_dict, selected_style):
    """Generate images for the inputs and wrap each result in a ``gr.Image`` update.

    Falsy results (failed generations) become an update with ``value=None``.

    NOTE(review): not wired to any event in the visible file, and
    ``gr.Image.update`` appears to be the Gradio 3.x API — confirm the
    installed Gradio version before relying on this.
    """
    generated = update_images(sentence_mapping, character_dict, selected_style)
    return [gr.Image.update(value=item if item else None) for item in generated]

# Gradio Interface
def update_images(sentence_mapping, character_dict, selected_style):
    """Build prompts from the three UI inputs and generate one image per prompt.

    Returns:
        The list produced by ``async_image_generation`` (one entry per prompt).
    """
    return async_image_generation(
        generate_prompt(sentence_mapping, character_dict, selected_style)
    )

def _images_to_gallery_paths(image_bytes_list):
    """Write each generated image to a temporary ``.png`` file and return the paths.

    ``gr.Gallery`` takes a list of images (file paths, PIL images or arrays);
    it cannot display raw ``bytes``, nor a list of ``gr.Image.update`` dicts.
    Falsy entries (failed generations) are skipped.

    NOTE(review): assumes the bytes are PNG-encoded; files are left in the
    system temp directory and are not cleaned up here.
    """
    paths = []
    for img_bytes in image_bytes_list:
        if not img_bytes:
            continue
        # delete=False: Gradio reads the file after this function returns.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
            handle.write(img_bytes)
        paths.append(handle.name)
    return paths


with gr.Blocks() as demo:
    sentence_mapping_input = gr.Textbox(label="Sentence Mapping")
    character_dict_input = gr.Textbox(label="Character Dictionary")
    selected_style_input = gr.Textbox(label="Selected Style")

    output_images = gr.Gallery(label="Generated Images")

    def generate_and_update_images(sentence_mapping, character_dict, selected_style):
        """Generate images for the current inputs and return them for the gallery."""
        image_bytes_list = update_images(sentence_mapping, character_dict, selected_style)
        return _images_to_gallery_paths(image_bytes_list)

    prompt_inputs = [sentence_mapping_input, character_dict_input, selected_style_input]
    # Regenerate whenever any of the three inputs changes.
    # NOTE(review): Textbox.change fires on every edit, so each keystroke
    # starts a full generation; consider .submit or a button instead.
    for component in prompt_inputs:
        component.change(
            fn=generate_and_update_images,
            inputs=prompt_inputs,
            outputs=output_images,
        )

if __name__ == "__main__":
    demo.launch()