File size: 3,175 Bytes
14e57b0
f6b8b7e
dacf4f2
789e6b5
9872917
b85438c
789e6b5
66e43e3
13298a2
c9b9787
f466dd9
c9b9787
dacf4f2
9d2d2d3
9872917
789e6b5
 
 
9d2d2d3
 
 
c9b9787
9d2d2d3
 
c9b9787
9d2d2d3
c301a62
c9b9787
f466dd9
c9b9787
789e6b5
d26a101
789e6b5
 
9872917
e0ec116
b0bcf89
 
e0ec116
 
de2c9e2
789e6b5
9872917
b0bcf89
06a3d1b
c9b9787
06a3d1b
789e6b5
06a3d1b
 
f45116f
 
c7a48c9
dacf4f2
b0bcf89
6035350
f466dd9
dacf4f2
de2c9e2
 
dacf4f2
 
9872917
f466dd9
789e6b5
d05fa5e
1adc78a
 
e0ec116
d05fa5e
c7a48c9
dacf4f2
ef9b8ab
f466dd9
 
9e12e12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from generate_prompts import generate_prompt
import gradio as gr
from diffusers import AutoPipelineForText2Image
from io import BytesIO
import asyncio

# Load the model once, at import time and outside of any function, so that
# generate_image reuses this single pipeline object on every call instead of
# loading the "stabilityai/sdxl-turbo" pipeline again per request.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

async def generate_image(prompt, prompt_name):
    """Render one image for *prompt* and return it encoded as JPEG bytes.

    The module-level ``model`` pipeline is invoked in a worker thread via
    ``asyncio.to_thread`` with a single inference step and a guidance scale
    of 0.0. Failures are reported with ``print`` and never raised to the
    caller.

    Args:
        prompt: Text prompt handed to the pipeline.
        prompt_name: Label used only in the progress and error messages.

    Returns:
        The JPEG-encoded bytes of the first image the pipeline produced, or
        ``None`` when generation or encoding failed.
    """
    # Phase 1: run the pipeline and pick the first image it produced.
    try:
        print(f"Generating image for {prompt_name}")
        result = await asyncio.to_thread(model, prompt=prompt, num_inference_steps=1, guidance_scale=0.0)

        produced = result.images
        if not isinstance(produced, list) or len(produced) == 0:
            raise Exception(f"No images returned by the model for {prompt_name}.")

        first_image = produced[0]
        stream = BytesIO()
    except Exception as exc:
        print(f"Error generating image for {prompt_name}: {exc}")
        return None

    # Phase 2: encode the image; encoding problems get their own message.
    try:
        first_image.save(stream, format="JPEG")
        payload = stream.getvalue()
        print(f"Image bytes length for {prompt_name}: {len(payload)}")
    except Exception as exc:
        print(f"Error saving image for {prompt_name}: {exc}")
        return None

    return payload

async def process_prompt(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph and return them keyed by paragraph.

    Args:
        sentence_mapping: Mapping of paragraph number to the list of sentences
            belonging to that paragraph. ``None`` or an empty mapping yields
            an empty result.
        character_dict: Character descriptions forwarded to
            ``generate_prompt``.
        selected_style: Illustration style forwarded to ``generate_prompt``.

    Returns:
        A dict mapping each paragraph number to whatever ``generate_image``
        returned for it (JPEG bytes, or ``None`` when that image failed).
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')

    # A missing (None) or empty mapping means there is nothing to illustrate;
    # return early instead of failing on ``None.items()``.
    if not sentence_mapping:
        return {}

    # Generate prompts for each paragraph
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # Run all image generations concurrently. gather() returns exactly one
    # result per awaitable, in submission order, so results line up with
    # ``prompts`` and no length check is needed.
    # NOTE(review): every task calls the same shared pipeline from a worker
    # thread — confirm the pipeline tolerates concurrent invocations.
    results = await asyncio.gather(
        *(generate_image(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts)
    )

    # Map results back to paragraphs
    return {
        paragraph_number: image
        for (paragraph_number, _), image in zip(prompts, results)
    }

# Helper function to generate a prompt based on the input
def generate_prompt(combined_sentence, character_dict, selected_style):
    """Build the text prompt for one paragraph's illustration.

    NOTE: this module-level definition rebinds the ``generate_prompt`` name
    imported from ``generate_prompts`` at the top of the file, so this is the
    version callers in this module actually use.

    Args:
        combined_sentence: The paragraph text to illustrate.
        character_dict: Mapping whose values are character descriptions; a
            value that is a list is flattened by joining its items with
            spaces, any other value is used as-is.
        selected_style: Name of the illustration style.

    Returns:
        The prompt string combining the style, all character descriptions
        (space-separated) and the paragraph text.
    """
    descriptions = []
    for entry in character_dict.values():
        if isinstance(entry, list):
            entry = " ".join(entry)
        descriptions.append(entry)

    characters = " ".join(descriptions)
    return f"Make an illustration in {selected_style} style from: {characters}. {combined_sentence}"

# Gradio interface exposing process_prompt: two JSON inputs (the sentence
# mapping and the character dictionary) plus a style dropdown, with the
# handler's return value sent to the "json" output.
# NOTE(review): process_prompt's dict values come from generate_image, which
# returns raw bytes (or None) — confirm the JSON output can serialize bytes.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json",
    concurrency_limit=20  # Set a high concurrency limit
).queue(default_concurrency_limit=20)  # the queue is given the same limit as its default

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    gradio_interface.launch()