File size: 3,696 Bytes
c513221
109adde
5c3986b
5e2c7ed
6b1b953
6706e0b
5e2c7ed
eb48f29
86743ba
5c3986b
 
6b1b953
 
 
 
 
 
5d9bf5a
33d78b0
eb48f29
6b1b953
 
 
efca12e
6b1b953
efca12e
6b1b953
6292767
 
 
 
5c3986b
eb48f29
 
 
 
 
 
194be56
6292767
eb48f29
 
 
3b7350e
834f7ba
5c3986b
28413d5
690f094
 
d253f4a
690f094
5c3986b
fd77b23
33d78b0
5c3986b
33d78b0
 
 
 
 
9cd3a95
cfeca25
efca12e
690f094
081cd9c
5e2c7ed
5c3986b
 
 
 
109adde
 
5c3986b
109adde
 
 
5c3986b
109adde
834f7ba
efca12e
109adde
081cd9c
690f094
bdf16c0
5c3986b
 
 
 
 
bdf16c0
eb48f29
bdf16c0
f466dd9
5c3986b
630a72e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import asyncio
from concurrent.futures import ProcessPoolExecutor
from io import BytesIO
from PIL import Image
from diffusers import StableDiffusionPipeline
import gradio as gr
from generate_prompts import generate_prompt

# Load the model once at the start
# NOTE(review): this runs at import time, so every worker process created by
# ProcessPoolExecutor presumably re-executes it on import and gets its own
# pipeline instance — heavy but unavoidable if the pipeline doesn't pickle;
# confirm against the executor's start method.
print("Loading the Stable Diffusion model...")
try:
    model = StableDiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")
    print("Model loaded successfully.")
except Exception as e:
    # Best-effort: keep the module importable even if the download/load fails.
    # generate_image checks `model is None` and reports the failure per call.
    print(f"Error loading model: {e}")
    model = None

def generate_image(prompt, prompt_name):
    """Render a single prompt through the diffusion pipeline.

    Args:
        prompt: Text prompt passed to the Stable Diffusion pipeline.
        prompt_name: Human-readable label used only in log messages.

    Returns:
        The generated image encoded as JPEG bytes, or None on any failure
        (model unavailable, empty pipeline output, or encoding error).
    """
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")

        print(f"Generating image for {prompt_name} with prompt: {prompt}")
        # num_inference_steps=1 / guidance_scale=0.0 match the sdxl-turbo call style.
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output for {prompt_name}: {output}")

        if output is None:
            raise ValueError(f"Model returned None for {prompt_name}")

        # Guard clause: bail out unless the pipeline produced at least one image.
        if not (hasattr(output, 'images') and output.images):
            print(f"No images found in model output for {prompt_name}")
            raise ValueError(f"No images found in model output for {prompt_name}")

        print(f"Image generated for {prompt_name}")
        first_image = output.images[0]
        jpeg_buffer = BytesIO()
        first_image.save(jpeg_buffer, format="JPEG")
        return jpeg_buffer.getvalue()
    except Exception as e:
        # Best-effort contract: callers treat None as "this paragraph failed".
        print(f"An error occurred while generating image for {prompt_name}: {e}")
        return None

async def queue_api_calls(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph, then render all images in parallel.

    Args:
        sentence_mapping: Mapping of paragraph_number -> list of sentences.
        character_dict: Character metadata forwarded to generate_prompt.
        selected_style: Art style forwarded to generate_prompt.

    Returns:
        Dict mapping paragraph_number -> JPEG bytes (or None for failures),
        in the same order the prompts were generated.
    """
    print("Starting to queue API calls...")
    numbered_prompts = []
    for para_num, sentence_list in sentence_mapping.items():
        joined_text = " ".join(sentence_list)
        built_prompt = generate_prompt(joined_text, sentence_mapping, character_dict, selected_style)
        numbered_prompts.append((para_num, built_prompt))
        print(f"Generated prompt for paragraph {para_num}: {built_prompt}")

    event_loop = asyncio.get_running_loop()
    # Fan the CPU/GPU-heavy generation out to worker processes; gather keeps
    # result order aligned with numbered_prompts.
    with ProcessPoolExecutor() as executor:
        pending = [
            event_loop.run_in_executor(executor, generate_image, built_prompt, f"Prompt {para_num}")
            for para_num, built_prompt in numbered_prompts
        ]
        results = await asyncio.gather(*pending)

    images = {}
    for (para_num, _), image_bytes in zip(numbered_prompts, results):
        images[para_num] = image_bytes
    print("Finished queuing API calls. Generated images: ", images)
    return images

def process_prompt(sentence_mapping, character_dict, selected_style):
    """Synchronous Gradio entry point: run the async pipeline to completion.

    Args:
        sentence_mapping: Mapping of paragraph_number -> list of sentences.
        character_dict: Character metadata forwarded to queue_api_calls.
        selected_style: Art style forwarded to queue_api_calls.

    Returns:
        Dict mapping paragraph_number -> image bytes (or None for failures),
        as produced by queue_api_calls.

    Raises:
        RuntimeError: if called from a thread that already has a running
            event loop (a blocking call cannot nest inside a running loop).
    """
    print("Processing prompt...")
    print(f"Sentence Mapping: {sentence_mapping}")
    print(f"Character Dict: {character_dict}")
    print(f"Selected Style: {selected_style}")
    # Bug fix: the previous get_running_loop()/run_until_complete pattern was
    # broken — if a loop *was* already running in this thread,
    # run_until_complete would raise "This event loop is already running",
    # and the newly created fallback loop was never closed (resource leak).
    # asyncio.run creates a fresh loop, runs the coroutine, and always closes
    # the loop; Gradio invokes sync fns in worker threads with no running
    # loop, so this is the correct blocking bridge here.
    cmpt_return = asyncio.run(queue_api_calls(sentence_mapping, character_dict, selected_style))
    print("Prompt processing complete. Generated images: ", cmpt_return)
    return cmpt_return

# Gradio UI wiring: two JSON payloads plus a style dropdown feed
# process_prompt; the returned dict is rendered as JSON output.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json"
).queue(default_concurrency_limit=20)  # Set concurrency limit if needed

if __name__ == "__main__":
    # Launch only when run as a script — not when this module is imported
    # (e.g. by ProcessPoolExecutor worker processes re-importing the file).
    print("Launching Gradio interface...")
    gradio_interface.launch()