File size: 3,237 Bytes
bc0d978
0cdadc9
690f094
 
 
b85438c
0cdadc9
6dc4bbd
13298a2
081cd9c
f466dd9
c9b9787
690f094
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f466dd9
c9b9787
789e6b5
d26a101
690f094
 
 
 
081cd9c
690f094
 
 
 
 
 
081cd9c
690f094
 
 
 
 
bc0d978
690f094
 
 
 
 
 
081cd9c
690f094
081cd9c
690f094
 
 
 
081cd9c
690f094
 
 
 
 
 
 
 
 
 
 
f466dd9
 
690f094
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from diffusers import AutoPipelineForText2Image
from io import BytesIO
import asyncio
from generate_propmts import generate_prompt

# Load the SDXL-Turbo text-to-image pipeline a single time, at import, so every
# call to ``generate_image`` reuses these weights instead of reloading them.
model = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo"
)

async def generate_image(prompt, prompt_name):
    """Run the shared text-to-image pipeline for one prompt and return JPEG bytes.

    The blocking pipeline call is off-loaded to a worker thread with
    ``asyncio.to_thread`` so several prompts can be awaited concurrently.

    Args:
        prompt: Text prompt passed to the pipeline.
        prompt_name: Label used only in log messages.

    Returns:
        The first generated image encoded as JPEG ``bytes``, or ``None`` if
        generation or encoding failed. Errors are printed, never raised.
    """
    print(f"Generating image for {prompt_name}")
    try:
        # NOTE(review): every concurrent call shares the single module-level
        # ``model``; confirm the pipeline/scheduler tolerates being invoked from
        # several threads at once (presumably not guaranteed thread-safe).
        output = await asyncio.to_thread(model, prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
    except Exception as e:
        print(f"Error generating image for {prompt_name}: {e}")
        return None

    # Guard clause instead of raising an exception only to catch it ourselves.
    images = getattr(output, "images", None)
    if not isinstance(images, list) or not images:
        print(f"Error generating image for {prompt_name}: No images returned by the model for {prompt_name}.")
        return None

    image = images[0]
    buffered = BytesIO()
    try:
        # JPEG cannot store alpha or palette modes (e.g. RGBA, P); convert those
        # first so encoding does not fail.
        if image.mode not in ("RGB", "L", "CMYK"):
            image = image.convert("RGB")
        image.save(buffered, format="JPEG")
    except Exception as e:
        print(f"Error saving image for {prompt_name}: {e}")
        return None

    image_bytes = buffered.getvalue()
    print(f"Image bytes length for {prompt_name}: {len(image_bytes)}")
    return image_bytes

async def process_prompt(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph, running all generations concurrently.

    Args:
        sentence_mapping: Mapping of paragraph number -> sentences. Each value is
            either a list of sentences (joined with single spaces) or a single
            string (used as-is). ``None`` or an empty mapping yields ``{}``.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Passed through unchanged to ``generate_prompt``.

    Returns:
        dict mapping each paragraph number to whatever ``generate_image``
        returned for that paragraph's prompt (JPEG bytes, or ``None`` on failure).
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    if not sentence_mapping:
        # An empty gr.JSON input arrives as None; nothing to generate.
        return {}

    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        # A bare string would otherwise be joined character-by-character.
        if isinstance(sentences, str):
            combined_sentence = sentences
        else:
            combined_sentence = " ".join(map(str, sentences))
        prompt = generate_prompt(combined_sentence, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    print(f'prompts: {prompts}')
    # gather() returns exactly one result per awaitable, in submission order,
    # so zipping with the paragraph numbers pairs each image with its paragraph.
    results = await asyncio.gather(
        *(generate_image(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts)
    )
    return dict(zip((paragraph_number for paragraph_number, _ in prompts), results))

def generate_prompt(combined_sentence, character_dict, selected_style):
    """Build the text-to-image prompt for one paragraph.

    NOTE: this module-level definition rebinds the ``generate_prompt`` name
    imported from ``generate_propmts`` at the top of the file, so
    ``process_prompt`` calls this local version.

    Args:
        combined_sentence: The paragraph text to illustrate.
        character_dict: Mapping whose values describe characters. Each value is
            a single item or a list/tuple of items; everything is converted with
            ``str`` and joined with single spaces. ``None`` counts as empty.
        selected_style: Style name inserted into the prompt (e.g. "sketch").

    Returns:
        ``"Make an illustration in {style} style from: {characters}. {sentence}"``
    """
    parts = []
    for character in (character_dict or {}).values():
        if isinstance(character, (list, tuple)):
            # str() so numbers or other JSON scalars don't make join() raise.
            parts.append(" ".join(str(item) for item in character))
        else:
            parts.append(str(character))
    characters = " ".join(parts)
    return f"Make an illustration in {selected_style} style from: {characters}. {combined_sentence}"

async def _process_prompt_for_json(sentence_mapping, character_dict, selected_style):
    """Adapt ``process_prompt`` output for Gradio's JSON output component.

    ``process_prompt`` returns raw JPEG ``bytes`` per paragraph, which are not
    JSON-serializable. Encode each image as a base64 ASCII string, and pass
    ``None`` entries (failed generations) through unchanged.
    """
    import base64

    images = await process_prompt(sentence_mapping, character_dict, selected_style)
    return {
        paragraph_number: (
            base64.b64encode(image).decode("ascii") if image is not None else None
        )
        for paragraph_number, image in images.items()
    }


# Gradio interface with high concurrency limit
gradio_interface = gr.Interface(
    fn=_process_prompt_for_json,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json",
    concurrency_limit=20  # Set a high concurrency limit
).queue(default_concurrency_limit=20)

if __name__ == "__main__":
    gradio_interface.launch()