File size: 1,994 Bytes
f466dd9
 
a9b8939
9e09422
 
cd715cb
b5ad13a
 
a9b8939
 
6d1d03a
8a0f059
f466dd9
b5ad13a
a9b8939
8a0f059
b5ad13a
8a0f059
b5ad13a
f466dd9
b5ad13a
8a0f059
f466dd9
b5ad13a
8a0f059
b5ad13a
 
 
 
 
8a0f059
 
b5ad13a
 
 
 
 
ca1d41c
f466dd9
 
 
b5ad13a
 
 
 
 
 
 
f466dd9
 
 
a9b8939
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import base64
import json
import logging
import threading
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image

from generate_propmts import generate_prompt

# The SDXL-Turbo text-to-image pipeline is built a single time at import, so
# every request reuses the same loaded weights instead of reloading them.
_MODEL_ID = "stabilityai/sdxl-turbo"
model = AutoPipelineForText2Image.from_pretrained(_MODEL_ID)

# Serializes use of the shared `model` pipeline. `inference` calls
# `generate_image` from several worker threads, and a diffusers pipeline keeps
# mutable per-call scheduler state, so concurrent calls on the one instance
# can interfere with each other.
_MODEL_LOCK = threading.Lock()


def generate_image(text, sentence_mapping, character_dict, selected_style):
    """Generate one image for *text* and return it as a base64-encoded JPEG.

    The prompt is built by ``generate_prompt`` from the given arguments; only
    its first return value is used.

    Args:
        text: The passage to illustrate.
        sentence_mapping: Forwarded unchanged to ``generate_prompt``.
        character_dict: Forwarded unchanged to ``generate_prompt``.
        selected_style: Forwarded unchanged to ``generate_prompt``.

    Returns:
        The JPEG bytes base64-encoded as a UTF-8 ``str``, or ``None`` if any
        step fails (the failure is logged with its traceback).
    """
    try:
        prompt, _ = generate_prompt(text, sentence_mapping, character_dict, selected_style)
        # Only the pipeline call touches shared state. Prompt building and
        # encoding stay outside the lock so they can still overlap.
        with _MODEL_LOCK:
            # Single sampling step with guidance disabled (guidance_scale=0.0).
            image = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0]
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception:
        # Best effort: returning None lets the remaining paragraphs in a
        # batch still be produced by `inference`.
        logging.getLogger(__name__).exception("Error generating image")
        return None

def _as_mapping(value, field_name):
    """Return *value* as a mapping, decoding it from JSON text when needed.

    Gradio ``Textbox`` components pass plain strings, so a mapping typed into
    the UI arrives as JSON text rather than a dict. Blank text is treated as an
    empty mapping.

    Raises:
        ValueError: if the text is not valid JSON or does not decode to a JSON
            object.
    """
    if isinstance(value, str):
        if not value.strip():
            return {}
        try:
            value = json.loads(value)
        except json.JSONDecodeError as err:
            raise ValueError(f"{field_name} is not valid JSON: {err}") from err
    if not isinstance(value, Mapping):
        raise ValueError(
            f"{field_name} must be a JSON object, got {type(value).__name__}"
        )
    return value


def inference(text, sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph and return them keyed by paragraph.

    Args:
        text: Not read by this function; kept so the signature matches the
            Gradio inputs.
        sentence_mapping: Paragraph number -> list of sentences, given either
            as a mapping or as JSON text (what a Gradio Textbox delivers). A
            paragraph whose value is a single string is used as-is.
        character_dict: Forwarded unchanged to ``generate_image``.
        selected_style: Forwarded unchanged to ``generate_image``.

    Returns:
        Dict of paragraph number -> base64 JPEG string, or ``None`` for a
        paragraph whose generation failed. Keys keep the input's order.

    Raises:
        ValueError: if ``sentence_mapping`` is neither a mapping nor JSON text
            encoding an object.
    """
    grouped_sentences = _as_mapping(sentence_mapping, "Sentence Mapping")
    # NOTE(review): character_dict also arrives as text from the Gradio UI;
    # confirm whether generate_prompt expects a decoded dict here.

    with ThreadPoolExecutor() as executor:
        futures = {}
        for paragraph_number, sentences in grouped_sentences.items():
            # " ".join on a bare string would insert a space between characters.
            if isinstance(sentences, str):
                combined_sentence = sentences
            else:
                combined_sentence = " ".join(sentences)
            futures[paragraph_number] = executor.submit(
                generate_image,
                combined_sentence,
                grouped_sentences,
                character_dict,
                selected_style,
            )
        # Collect in submission order so the result preserves the input's key order.
        images = {number: future.result() for number, future in futures.items()}

    return images

# Web UI wrapping `inference`. Input components are the top-level `gr.*`
# classes: the old `gr.inputs.*` namespace was deprecated in Gradio 3 and
# removed in Gradio 4. Textbox values reach `inference` as plain strings.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(label="Text"),
        gr.Textbox(label="Sentence Mapping"),
        gr.Textbox(label="Character Dict"),
        gr.Dropdown(["Style 1", "Style 2", "Style 3"], label="Selected Style"),
    ],
    outputs="json",  # Return the dictionary of images
)

if __name__ == "__main__":
    gradio_interface.launch()