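"""Gradio app that turns a paragraph-to-sentence mapping into illustration
prompts and renders them with SDXL-Turbo in a multiprocessing pool.

Flow:
  1. process_prompt parses the JSON inputs coming from the Gradio UI.
  2. process_prompts builds one prompt per paragraph via generate_prompt.
  3. A pool of worker processes (each loading its own pipeline in
     initialize_model) calls generate_image for every prompt.
  4. The result is a {prompt_name: base64-encoded JPEG} dict, returned as JSON.
"""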
import base64
import json
import multiprocessing
import os
from io import BytesIO

import gradio as gr
from diffusers import AutoPipelineForText2Image

from generate_prompts import generate_prompt

# Global handle for the SDXL-Turbo pipeline. Each worker process gets its own
# copy, populated once by initialize_model via the Pool initializer.
model = None

def initialize_model():
    global model
    if model is None:  # Ensure the model is loaded only once per process
        print("Loading the model...")
        model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
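        # NOTE: from_pretrained loads the pipeline on CPU by default; moving it
        # to a GPU (e.g. model.to("cuda")) is assumed to happen elsewhere if
        # one is available, and is not shown here.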
        print("Model loaded successfully.")

def generate_image(prompt, prompt_name):
    try:
        print(f"Generating response for {prompt_name} with prompt: {prompt}")
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Output for {prompt_name}: {output}")

        # Check if the model returned images
        if isinstance(output.images, list) and len(output.images) > 0:
            image = output.images[0]
            buffered = BytesIO()
            try:
                image.save(buffered, format="JPEG")
                image_bytes = buffered.getvalue()
                print(f"Image bytes length for {prompt_name}: {len(image_bytes)}")
                # Encode the JPEG bytes as base64 so the result survives JSON
                # serialization by the Gradio JSON output component.
                image_b64 = base64.b64encode(image_bytes).decode("utf-8")
                return prompt_name, image_b64
            except Exception as e:
                print(f"Error saving image for {prompt_name}: {e}")
                return prompt_name, None
        else:
            raise Exception(f"No images returned by the model for {prompt_name}.")
    except Exception as e:
        print(f"Error generating image for {prompt_name}: {e}")
        return prompt_name, None

def process_prompts(sentence_mapping, character_dict, selected_style):
    print(f"process_prompts called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
    
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        print(f"combined_sentence for paragraph {paragraph_number}: {combined_sentence}")
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    num_prompts = len(prompts)
    print(f"Number of prompts: {num_prompts}")

    # Cap the number of worker processes at the number of prompts and at the
    # CPU count (every worker loads its own copy of the pipeline); the floor of
    # 1 avoids Pool(processes=0) raising ValueError for an empty mapping.
    num_workers = max(1, min(num_prompts, os.cpu_count() or 1))
    with multiprocessing.Pool(processes=num_workers, initializer=initialize_model) as pool:
        tasks = [(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts]
        results = pool.starmap(generate_image, tasks)

    images = {prompt_name: image for prompt_name, image in results}
    print(f"Images generated: {images}")
    return images

def process_prompt(sentence_mapping, character_dict, selected_style):
    print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
    # Inputs may arrive from Gradio as JSON strings or as already-parsed dicts;
    # normalize them to dicts before processing.
    if isinstance(sentence_mapping, str):
        sentence_mapping = json.loads(sentence_mapping)
    if isinstance(character_dict, str):
        character_dict = json.loads(character_dict)
    return process_prompts(sentence_mapping, character_dict, selected_style)

gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json"
)
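
# A minimal sketch of how the running app could be called programmatically with
# gradio_client, assuming the default local URL and endpoint name; the example
# inputs below are hypothetical:
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860")
#   images = client.predict(
#       {"1": ["A fox walks through a snowy forest."]},  # Sentence Mapping
#       {"fox": "a small red fox with a bushy tail"},    # Character Dict
#       "watercolor",                                    # Selected Style
#       api_name="/predict",
#   )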

if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()
    # launch() blocks until the server is shut down, so this line runs on exit.
    print("Gradio interface stopped.")