import json
import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
import base64
from io import BytesIO
from generate_propmts import generate_prompt
from concurrent.futures import ThreadPoolExecutor

# Load the pipeline once at import time so every request reuses the same model,
# and move it to the GPU when one is available
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
model.to("cuda" if torch.cuda.is_available() else "cpu")

def generate_image(text, sentence_mapping, character_dict, selected_style):
    try:
        # Build the prompt from the paragraph text, sentence mapping, character details and style
        prompt, _ = generate_prompt(text, sentence_mapping, character_dict, selected_style)
        # SDXL-Turbo is tuned for single-step generation with guidance disabled
        image = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0]
        # Encode the PIL image as a base64 PNG string; writing image.tobytes() would emit
        # headerless raw pixel data that cannot be decoded as an image downstream
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return img_str
    except Exception as e:
        print(f"Error generating image: {e}")
        return None

def inference(sentence_mapping, character_dict, selected_style):
    images = {}
    print(f'sentence_mapping:{sentence_mapping}, character_dict:{character_dict}, selected_style:{selected_style}')
    
    # Parse sentence_mapping JSON string into a dictionary
    try:
        grouped_sentences = json.loads(sentence_mapping)
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON: {e}")
        return {"error": "Invalid JSON input for sentence_mapping"}

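    # Fan out one generation job per paragraph; all workers share the single pipeline loaded above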
    with ThreadPoolExecutor() as executor:
        futures = {}
        for paragraph_number, sentences in grouped_sentences.items():
            combined_sentence = " ".join(sentences)
            futures[paragraph_number] = executor.submit(generate_image, combined_sentence, grouped_sentences, character_dict, selected_style)

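        # Collect results in paragraph order; future.result() blocks until that image is ready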
        for paragraph_number, future in futures.items():
            images[paragraph_number] = future.result()

    return images

gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(label="Sentence Mapping (JSON)"),
        gr.Textbox(label="Character Dict (JSON)"),
        gr.Dropdown(["Style 1", "Style 2", "Style 3"], label="Selected Style")
    ],
    outputs="json"  # inference returns a dict of paragraph number -> base64 image (or an error message)
)
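
# Example inputs for a quick local test (hypothetical values; real payloads come from
# whatever client calls this interface):
#   sentence_mapping: {"1": ["A knight rides into the forest.", "Night falls quickly."]}
#   character_dict:   {"knight": "a tall knight in weathered silver armor"}
#   selected_style:   "Style 1"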

if __name__ == "__main__":
    gradio_interface.launch()