import threading
import traceback
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
from diffusers import AutoPipelineForText2Image

from generate_propmts import generate_prompt

# Load the pipeline once at import time so every request reuses the same weights
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
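
# Optional GPU placement: a minimal sketch, assuming a CUDA-enabled torch build
# (torch itself is installed as a diffusers dependency); .to("cuda") moves the weights.
import torch
if torch.cuda.is_available():
    model = model.to("cuda")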

# Thread-local storage: each worker thread keeps its own image counter
scheduler_step_storage = threading.local()

def generate_image(prompt):
    try:
        # Initialize this thread's counter on first use
        if not hasattr(scheduler_step_storage, 'step'):
            scheduler_step_storage.step = 0

        # SDXL-Turbo is distilled for single-step generation without guidance
        output = model(
            prompt=prompt,
            num_inference_steps=1,  # one denoising step, as recommended for SDXL-Turbo
            guidance_scale=0.0,     # Turbo models are trained without classifier-free guidance
            output_type="pil"       # return PIL Image objects directly
        )

        # Count the images generated on this thread (bookkeeping only;
        # the counter is never passed to the pipeline)
        scheduler_step_storage.step += 1
        
        # Return the first image if the pipeline produced any
        if output.images:
            return output.images[0]
        else:
            raise RuntimeError("No images returned by the model.")
    
    except Exception as e:
        print(f"Error generating image: {e}")
        traceback.print_exc()
        return None  # Return None on error to handle it gracefully in the UI


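# Example inputs for inference(), for illustration only; the exact JSON shapes
# are an assumption inferred from how the values are used below:
#   sentence_mapping: {"1": ["A knight rides at dawn.", "Mist rolls over the field."]}
#   character_dict:   {"knight": "a tall figure in weathered silver armor"}
#   selected_style:   "watercolor"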
def inference(sentence_mapping, character_dict, selected_style):
    images = []
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    prompts = []

    # Generate prompts for each paragraph
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append(prompt)
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(generate_image, prompt) for prompt in prompts]

        # Iterate in submission order so images stay aligned with their paragraphs
        for future in futures:
            try:
                image = future.result()
                if image:
                    images.append(image)
            except Exception as e:
                print(f"Error processing prompt: {e}")
                traceback.print_exc()

    return images

gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs=gr.Gallery(label="Generated Images")
)

if __name__ == "__main__":
    gradio_interface.launch()