File size: 2,631 Bytes
f6b8b7e
789e6b5
13298a2
789e6b5
 
 
216a041
b85438c
789e6b5
66e43e3
13298a2
216a041
f466dd9
789e6b5
 
 
 
 
 
 
 
 
 
 
c301a62
6449f8f
789e6b5
f466dd9
b5ad13a
789e6b5
d26a101
789e6b5
 
b0bcf89
e0ec116
b0bcf89
 
e0ec116
 
6449f8f
789e6b5
b0bcf89
 
789e6b5
 
 
 
 
 
 
 
 
 
 
 
 
 
b0bcf89
6035350
f466dd9
 
789e6b5
d05fa5e
1adc78a
 
e0ec116
d05fa5e
789e6b5
 
58f74fc
f466dd9
 
a9b8939
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
from io import BytesIO
from generate_propmts import generate_prompt
from concurrent.futures import ThreadPoolExecutor
import asyncio

# Instantiate the SDXL-Turbo text-to-image pipeline a single time, at import,
# so generate_image reuses the same loaded object for every request instead of
# reloading weights per call.
# NOTE(review): `torch` is imported but unused and no device/dtype is passed,
# so the pipeline stays on from_pretrained's default placement (presumably
# CPU) — confirm whether `.to("cuda")` / `torch_dtype=...` was intended.
model = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo"
)

def generate_image(prompt, *, pipeline=None, num_inference_steps=1,
                   guidance_scale=0.0, image_format="JPEG"):
    """Render *prompt* with the text-to-image pipeline and return encoded bytes.

    This is intentionally a plain (synchronous) function: it only does
    blocking work and is dispatched to worker threads with
    ``loop.run_in_executor``. When it was declared ``async``, the executor
    merely created a coroutine object and handed that back instead of image
    bytes.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        pipeline: Callable pipeline to use; defaults to the module-level
            ``model``.
        num_inference_steps: Number of inference steps passed to the
            pipeline (default 1, as originally hard-coded).
        guidance_scale: Guidance scale passed to the pipeline (default 0.0,
            as originally hard-coded).
        image_format: Format name passed to ``image.save`` (default "JPEG").

    Returns:
        The first returned image encoded as *image_format* bytes, or ``None``
        when generation raised or the pipeline returned no images. Failures
        are printed, not raised, so one bad prompt does not abort a batch.

    NOTE(review): process_prompt may call this from several threads at once
    against the same pipeline object; confirm the pipeline tolerates
    concurrent calls (diffusers pipelines may not be thread-safe).
    """
    try:
        # Resolved inside the try so a missing module-level model is reported
        # like any other generation failure.
        pipe = model if pipeline is None else pipeline
        output = pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        print(f"Model output: {output}")

        images = output.images
        # Only a non-empty list is treated as success.
        if not (isinstance(images, list) and images):
            print("Error generating image: No images returned by the model.")
            return None

        buffered = BytesIO()
        images[0].save(buffered, format=image_format)
        return buffered.getvalue()
    except Exception as e:
        # Deliberate best-effort boundary: report and signal failure with None.
        print(f"Error generating image: {e}")
        return None

async def process_prompt(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph and render the images concurrently.

    Args:
        sentence_mapping: Mapping of paragraph number -> sequence of sentence
            strings; each paragraph's sentences are joined with single spaces
            before prompting. ``None`` or empty yields ``{}``.
        character_dict: Forwarded unchanged to ``generate_prompt``.
        selected_style: Style label forwarded unchanged to ``generate_prompt``.

    Returns:
        Dict mapping paragraph number -> the value returned by
        ``generate_image`` (image bytes). Paragraphs whose generation raised
        are printed and omitted; paragraphs with a falsy result (e.g. ``None``)
        are omitted silently.

    NOTE(review): this dict is returned to a ``"json"`` Gradio output; raw
    ``bytes`` values are presumably not JSON-serializable there — confirm
    whether callers expect an encoding such as base64 instead.
    """
    images = {}
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    if not sentence_mapping:
        return images

    # Build every prompt up front so image generation can fan out afterwards.
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    def _render(prompt):
        # Runs on a worker thread. If generate_image is declared ``async``,
        # calling it only creates a coroutine; drive it to completion on this
        # thread's own event loop rather than returning the un-awaited
        # coroutine object as if it were an image.
        result = generate_image(prompt)
        if asyncio.iscoroutine(result):
            result = asyncio.run(result)
        return result

    loop = asyncio.get_running_loop()
    # NOTE(review): generate_image (as written in this file) calls the single
    # module-level pipeline, so these threads share it; confirm it tolerates
    # concurrent calls.
    with ThreadPoolExecutor() as executor:
        futures = [loop.run_in_executor(executor, _render, prompt) for _, prompt in prompts]
        # return_exceptions=True: one failing paragraph must not discard the
        # results of all the others.
        results = await asyncio.gather(*futures, return_exceptions=True)

    # gather preserves order, so results line up with the prompts list.
    for (paragraph_number, _), result in zip(prompts, results):
        if isinstance(result, BaseException):
            print(f"Error processing paragraph {paragraph_number}: {result}")
        elif result:
            images[paragraph_number] = result

    return images

# Web UI: the three input components map positionally onto process_prompt's
# (sentence_mapping, character_dict, selected_style) parameters, and the
# returned dict is shown through a JSON output component.
gradio_interface = gr.Interface(
    fn=process_prompt,
    outputs="json",
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(
            ["oil painting", "sketch", "watercolor"],
            label="Selected Style",
        ),
    ],
    # Permit up to 10 simultaneous executions of process_prompt.
    concurrency_limit=10,
)

if __name__ == "__main__":
    # Serve the UI only when executed directly, not when imported.
    gradio_interface.launch()