File size: 2,890 Bytes
f6b8b7e
13298a2
f6b8b7e
c1282a1
216a041
13298a2
c1282a1
b5ad13a
13298a2
f6b8b7e
b85438c
13298a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d1d03a
216a041
f466dd9
13298a2
 
 
 
216a041
 
 
13298a2
 
 
d26a101
13298a2
 
d26a101
f6b8b7e
c301a62
6449f8f
f466dd9
b5ad13a
c1282a1
13298a2
d26a101
216a041
6035350
1d0b035
e0ec116
 
f6b8b7e
e0ec116
 
6449f8f
6035350
e0ec116
b5ad13a
216a041
 
 
e0ec116
216a041
 
b5ad13a
6035350
f466dd9
 
 
d05fa5e
1adc78a
 
e0ec116
d05fa5e
6035350
58f74fc
f466dd9
 
a9b8939
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
from diffusers import AutoPipelineForText2Image
from generate_propmts import generate_prompt
from PIL import Image
import asyncio
import threading
import traceback

# Load the SDXL-Turbo text-to-image pipeline once, at import time, so every
# call to generate_image reuses the same instance instead of reloading it.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

class SchedulerWrapper:
    """Proxy around a scheduler that counts ``step`` calls per thread.

    ``step`` is forwarded to the wrapped scheduler and retried once if it
    raises ``IndexError``.  Every other attribute (methods, properties,
    configuration) is delegated to the wrapped scheduler so the wrapper can
    stand in for it wherever the scheduler is used.

    Attributes:
        scheduler: The wrapped scheduler object.
        _step: ``threading.local`` holding the per-thread ``step`` counter.
    """

    def __init__(self, scheduler):
        """Wrap *scheduler* and start the calling thread's counter at 0."""
        self.scheduler = scheduler
        self._step = threading.local()
        # Only visible in the constructing thread; other threads fall back to
        # the default used in ``step``.
        self._step.step = 0

    def __getattr__(self, name):
        """Delegate attributes the wrapper does not define to the scheduler.

        Python calls this only after normal lookup has failed, so ``step``,
        ``scheduler`` and ``_step`` are never routed through here once set.
        """
        # Guard against infinite recursion while the instance is only partly
        # initialised (e.g. during copy/unpickle), and keep special-method
        # lookups from leaking through to the wrapped object.
        if name.startswith("__") or name in ("scheduler", "_step"):
            raise AttributeError(name)
        return getattr(self.scheduler, name)

    def step(self, *args, **kwargs):
        """Forward to ``scheduler.step``, retrying once on ``IndexError``.

        Returns:
            Whatever the wrapped scheduler's ``step`` returns.

        Raises:
            IndexError: If the retry fails as well.
        """
        # threading.local attributes are per-thread, so a thread that did not
        # run __init__ has no ``step`` yet; default it to 0 instead of
        # failing with AttributeError.
        self._step.step = getattr(self._step, "step", 0) + 1
        try:
            return self.scheduler.step(*args, **kwargs)
        except IndexError:
            self._step.step = 0
            return self.scheduler.step(*args, **kwargs)

# Swap the pipeline's scheduler for a SchedulerWrapper around the original one.
model.scheduler = SchedulerWrapper(model.scheduler)

async def generate_image(prompt, num_inference_steps=5, guidance_scale=0.0):
    """Generate one image for *prompt* with the module-level pipeline.

    The blocking pipeline call runs in a worker thread via
    ``asyncio.to_thread`` so the event loop stays responsive.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps to run. Defaults to 5,
            the value this function previously hard-coded.
        guidance_scale: Classifier-free guidance scale. Defaults to 0.0,
            the value this function previously hard-coded; the SDXL-Turbo
            model card says to run the model with guidance disabled.

    Returns:
        The first image from the pipeline output (``output_type="pil"``), or
        ``None`` if generation failed or produced no images.  Failures are
        printed with a traceback rather than raised, so callers can simply
        skip missing images.
    """
    try:
        output = await asyncio.to_thread(
            model,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            output_type="pil",  # Directly get PIL Image objects
        )

        # Treat an empty result like any other failure so it is logged by the
        # handler below and reported to the caller as None.
        if not output.images:
            raise RuntimeError("No images returned by the model.")
        return output.images[0]
    except Exception as e:
        print(f"Error generating image: {e}")
        traceback.print_exc()
        return None  # Return None on error to handle it gracefully in the UI

async def inference(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph and generate an image for each.

    Args:
        sentence_mapping: Mapping of paragraph number to an iterable of
            sentence strings.  ``None`` or an empty mapping yields no images.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Passed through unchanged to ``generate_prompt``.

    Returns:
        A list of the images that were generated successfully, in paragraph
        order.  Paragraphs whose generation failed (``generate_image``
        returned ``None``) are omitted.
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')

    # An empty JSON field can arrive as None; there is nothing to generate.
    if not sentence_mapping:
        return []

    # Generate prompts for each paragraph
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append(prompt)
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # Every generate_image call drives the same module-level pipeline, so run
    # them one at a time: overlapping runs would interleave on the pipeline's
    # shared scheduler state.  Each call still executes in a worker thread,
    # so the event loop is not blocked while an image is being generated.
    images = []
    for prompt in prompts:
        image = await generate_image(prompt)
        if image is not None:
            images.append(image)

    return images

# Gradio UI definition: the three input components are passed positionally to
# inference(sentence_mapping, character_dict, selected_style), and the list of
# images it returns is shown in the gallery.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs=gr.Gallery(label="Generated Images")
)

# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    gradio_interface.launch()