File size: 2,768 Bytes
f466dd9
 
a9b8939
9e09422
cd715cb
b5ad13a
1adc78a
b5ad13a
a9b8939
 
6d1d03a
c301a62
 
 
 
 
8a555c8
 
c301a62
 
c0cd59a
f466dd9
c0cd59a
a97be70
c301a62
 
8a555c8
1d0b035
 
8a555c8
c301a62
 
 
 
 
 
 
 
 
 
f466dd9
b5ad13a
8a0f059
f466dd9
c0cd59a
8a0f059
1d0b035
1adc78a
 
b5ad13a
 
 
8a0f059
 
1adc78a
b5ad13a
 
 
 
ca1d41c
f466dd9
 
 
d05fa5e
1adc78a
 
d05fa5e
 
c301a62
f466dd9
 
 
a9b8939
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import base64
import json
import threading
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image

from generate_propmts import generate_prompt

# Hugging Face Hub identifier of the text-to-image checkpoint served here.
_MODEL_ID = "stabilityai/sdxl-turbo"

# Build the pipeline a single time at import so every request reuses the same
# loaded weights rather than reloading them inside each call.
model = AutoPipelineForText2Image.from_pretrained(_MODEL_ID)

# Helper function to truncate prompt to fit the model's maximum sequence length
def truncate_prompt(prompt, max_length=77):
    tokens = prompt.split()
    if len(tokens) > max_length:
        return ' '.join(tokens[:max_length])
    print("len of tokens:", len(tokens))
    print("len of tokens:", len(prompt))
    return prompt

# Serializes access to the shared module-level `model` pipeline. `inference`
# calls `generate_image` from several worker threads, and a diffusers pipeline
# keeps per-call state (e.g. the scheduler's timesteps) on the shared object,
# so concurrent calls into it are not safe. Prompt generation still runs in
# parallel; only the model call itself is serialized.
_MODEL_LOCK = threading.Lock()


def generate_image(text, sentence_mapping, character_dict, selected_style):
    """Generate one image for *text* and return it as JPEG-encoded bytes.

    The prompt is built by ``generate_prompt`` and word-truncated with
    ``truncate_prompt``. It is then passed to the module-level ``model``
    pipeline with a single inference step and ``guidance_scale=0.0``.

    Args:
        text: The paragraph text to illustrate.
        sentence_mapping: Forwarded to ``generate_prompt``.
        character_dict: Forwarded to ``generate_prompt``.
        selected_style: Forwarded to ``generate_prompt``.

    Returns:
        The first generated image encoded as JPEG ``bytes``. Returns ``None``
        if prompt or image generation raised, or if the model returned no
        images. Errors are printed rather than raised, so one failed paragraph
        does not abort the whole batch in ``inference``.
    """
    try:
        prompt, _ = generate_prompt(text, sentence_mapping, character_dict, selected_style)
        print(f"Generated prompt: {prompt}")
        # Truncate prompt if necessary
        prompt = truncate_prompt(prompt)
        print(f"truncate_prompt: {prompt}")
        with _MODEL_LOCK:
            output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)

        images = output.images
        if not images:
            print("Error generating image: No images returned by the model.")
            return None
        print(f"Model returned {len(images)} image(s)")

        buffered = BytesIO()
        images[0].save(buffered, format="JPEG")
        return buffered.getvalue()
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and let the
        # caller record ``None`` for this paragraph.
        print(f"Error generating image: {e}")
        return None

def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph of *sentence_mapping*, in parallel.

    Args:
        sentence_mapping: Mapping of paragraph key -> that paragraph's
            sentences. List values are joined with single spaces; a plain
            string value is used as-is. Joining a string would otherwise
            space out its individual characters. The whole mapping is also
            forwarded to ``generate_image`` as context.
        character_dict: Forwarded to ``generate_image``.
        selected_style: Forwarded to ``generate_image``.

    Returns:
        Dict mapping each paragraph key to the result of ``generate_image``
        for that paragraph (JPEG ``bytes``, or ``None`` on failure). Returns
        an empty dict when ``sentence_mapping`` is empty or ``None``.
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    if not sentence_mapping:
        return {}

    with ThreadPoolExecutor() as executor:
        futures = {
            paragraph_number: executor.submit(
                generate_image,
                sentences if isinstance(sentences, str) else " ".join(sentences),
                sentence_mapping,
                character_dict,
                selected_style,
            )
            for paragraph_number, sentences in sentence_mapping.items()
        }
        # Collect in the mapping's own key order; result() blocks per future.
        return {paragraph_number: future.result() for paragraph_number, future in futures.items()}

def _inference_json(sentence_mapping, character_dict, selected_style):
    """Run ``inference`` and make its result JSON-serializable for Gradio.

    ``inference`` returns raw JPEG ``bytes`` per paragraph, which JSON cannot
    represent. Each image is base64-encoded here into an ASCII ``str``.
    Paragraphs whose generation failed stay ``None`` (JSON ``null``).
    """
    images = inference(sentence_mapping, character_dict, selected_style)
    return {
        key: base64.b64encode(data).decode("ascii") if data is not None else None
        for key, data in images.items()
    }


gradio_interface = gr.Interface(
    fn=_inference_json,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["Style 1", "Style 2", "Style 3"], label="Selected Style")
    ],
    outputs="json"
)

if __name__ == "__main__":
    gradio_interface.launch()