File size: 3,179 Bytes
14e57b0
f6b8b7e
9872917
dacf4f2
789e6b5
9872917
b85438c
789e6b5
66e43e3
13298a2
216a041
f466dd9
06a3d1b
dacf4f2
9872917
9d2d2d3
9872917
789e6b5
 
 
9d2d2d3
 
 
 
 
 
 
 
 
c301a62
6449f8f
f466dd9
b5ad13a
789e6b5
d26a101
789e6b5
 
9872917
e0ec116
b0bcf89
 
e0ec116
 
de2c9e2
789e6b5
9872917
b0bcf89
06a3d1b
02161de
06a3d1b
789e6b5
06a3d1b
 
f45116f
 
c7a48c9
dacf4f2
b0bcf89
6035350
f466dd9
dacf4f2
de2c9e2
 
dacf4f2
 
9872917
f466dd9
789e6b5
d05fa5e
1adc78a
 
e0ec116
d05fa5e
c7a48c9
dacf4f2
ef9b8ab
f466dd9
 
c7a48c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import asyncio
import base64
import threading
from io import BytesIO

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image

from generate_prompts import generate_prompt

# Load the SDXL-Turbo text-to-image pipeline once, at import time, so every
# request reuses the same pipeline object (generate_image calls this global)
# instead of reloading the weights per call.
# NOTE(review): no torch_dtype or device is passed here, so the pipeline stays
# at diffusers' from_pretrained defaults (presumably float32 on CPU) even though
# torch is imported — confirm whether e.g. torch_dtype=torch.float16 and
# .to("cuda") were intended.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

# Serializes calls into the shared `model` pipeline. generate_image() is fanned
# out with asyncio.gather + asyncio.to_thread, so without this lock several
# worker threads could drive the same pipeline object at the same time; a
# single diffusers pipeline is not designed for concurrent calls (its scheduler
# keeps per-run state such as the timestep schedule).
_MODEL_LOCK = threading.Lock()


def _run_model(prompt):
    """Run one SDXL-Turbo inference for *prompt* while holding the model lock."""
    with _MODEL_LOCK:
        # SDXL-Turbo is meant for single-step sampling without CFG.
        return model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)


async def generate_image(prompt):
    """Generate one image for *prompt* and return it as JPEG bytes.

    The blocking pipeline call runs in a worker thread via ``asyncio.to_thread``
    so the event loop stays responsive; calls into the shared pipeline are
    serialized by ``_MODEL_LOCK``.

    Args:
        prompt: Text prompt passed to the text-to-image pipeline.

    Returns:
        The first generated image encoded as JPEG ``bytes``, or ``None`` if
        generation failed, produced no images, or the image could not be
        encoded. Failures are printed rather than raised, so one bad prompt
        does not abort an ``asyncio.gather`` over many prompts.
    """
    try:
        output = await asyncio.to_thread(_run_model, prompt)
    except Exception as e:  # best-effort per prompt: report and yield None
        print(f"Error generating image: {e}")
        return None
    print(f"Model output: {output}")

    images = getattr(output, "images", None)
    if not isinstance(images, list) or not images:
        print("Error generating image: No images returned by the model.")
        return None

    image = images[0]
    buffered = BytesIO()
    try:
        # JPEG cannot store an alpha channel (e.g. RGBA), so normalize first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(buffered, format="JPEG")
    except Exception as e:  # best-effort per prompt: report and yield None
        print(f"Error saving image: {e}")
        return None

    image_bytes = buffered.getvalue()
    print(f"Image bytes length: {len(image_bytes)}")
    return image_bytes

async def process_prompt(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph and generate an image for each.

    Args:
        sentence_mapping: Mapping of paragraph identifier -> that paragraph's
            sentences, either a list of sentences (joined with single spaces)
            or one string used as-is. ``None`` (e.g. a blank ``gr.JSON``
            input) is treated as an empty mapping.
        character_dict: Mapping forwarded to ``generate_prompt`` to describe
            the characters. ``None`` is treated as an empty mapping.
        selected_style: Style name forwarded to ``generate_prompt``.

    Returns:
        dict mapping each paragraph identifier to the value ``generate_image``
        returned for its prompt (JPEG bytes, or ``None`` on failure).
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    sentence_mapping = sentence_mapping or {}
    character_dict = character_dict or {}

    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        # " ".join on a bare string would space out its characters.
        if isinstance(sentences, str):
            combined_sentence = sentences
        else:
            combined_sentence = " ".join(map(str, sentences))
        prompt = generate_prompt(combined_sentence, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # gather() returns exactly one result per awaitable, in submission order,
    # so results line up with `prompts` one-to-one.
    results = await asyncio.gather(*(generate_image(prompt) for _, prompt in prompts))
    return {
        paragraph_number: image
        for (paragraph_number, _), image in zip(prompts, results)
    }

# Helper function to generate a prompt based on the input
def generate_prompt(combined_sentence, character_dict, selected_style):
    """Build the text-to-image prompt for one paragraph.

    This module-level definition is bound after the ``generate_prompt``
    imported from ``generate_prompts`` at the top of the file, so it replaces
    that import for every caller in this module.

    Args:
        combined_sentence: Paragraph text to illustrate.
        character_dict: Mapping whose values describe characters. Each value
            is either a single item or a list/tuple of items joined with
            spaces; non-string items are converted with ``str``. ``None`` is
            treated as no characters.
        selected_style: Style name inserted verbatim (e.g. ``"sketch"``).

    Returns:
        ``"Make an illustration in {style} style from: {characters}. {text}"``
    """

    def _describe(character):
        # Lists/tuples of name parts collapse into one space-separated string.
        if isinstance(character, (list, tuple)):
            return " ".join(str(part) for part in character)
        return str(character)

    characters = " ".join(_describe(c) for c in (character_dict or {}).values())
    return f"Make an illustration in {selected_style} style from: {characters}. {combined_sentence}"

async def _process_prompt_json(sentence_mapping, character_dict, selected_style):
    """Run ``process_prompt`` and make its result JSON-representable.

    ``process_prompt`` maps each paragraph to raw JPEG ``bytes`` (or ``None``),
    which a ``"json"`` output cannot carry meaningfully. Each ``bytes`` value
    is returned as a base64-encoded ASCII string instead; other values (such
    as ``None`` for a failed image) pass through unchanged.
    """
    images = await process_prompt(sentence_mapping, character_dict, selected_style)
    return {
        paragraph: (
            base64.b64encode(data).decode("ascii")
            if isinstance(data, (bytes, bytearray))
            else data
        )
        for paragraph, data in images.items()
    }


# Gradio interface: JSON paragraph/character inputs plus a style dropdown;
# output is JSON mapping paragraph -> base64 JPEG (or null on failure).
gradio_interface = gr.Interface(
    fn=_process_prompt_json,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
    concurrency_limit=20,  # up to 20 requests for this event may run at once
).queue(default_concurrency_limit=20)

def _main():
    """Launch the Gradio app, calling launch() with no share argument."""
    # A public share=True link isn't needed for local testing.
    gradio_interface.launch()


if __name__ == "__main__":
    _main()