File size: 2,218 Bytes
f6b8b7e
13298a2
66e43e3
c1282a1
216a041
b85438c
13298a2
 
 
66e43e3
 
 
 
 
 
 
13298a2
66e43e3
 
 
 
 
 
13298a2
 
66e43e3
 
 
6d1d03a
216a041
f466dd9
66e43e3
216a041
 
 
13298a2
66e43e3
 
d26a101
 
f6b8b7e
c301a62
6449f8f
f466dd9
b5ad13a
66e43e3
d26a101
216a041
6035350
e0ec116
 
 
6449f8f
6035350
216a041
 
 
6035350
f466dd9
 
 
d05fa5e
1adc78a
 
e0ec116
d05fa5e
6035350
58f74fc
f466dd9
 
a9b8939
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import asyncio
import logging
import threading

import gradio as gr
from diffusers import AutoPipelineForText2Image
from PIL import Image
from transformers import AutoTokenizer

class SchedulerWrapper:
    """Transparent proxy around a scheduler object.

    Any attribute not found on the wrapper itself is read from the wrapped
    ``scheduler`` via ``__getattr__``, so the wrapper can replace the original
    scheduler on the pipeline (see ``model.scheduler = wrapped_scheduler``).
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup fails. If "scheduler"
        # itself is missing (e.g. an instance built by copy/pickle before its
        # __dict__ is restored), forwarding would recurse forever; raise a
        # normal AttributeError so hasattr()/copy/pickle behave correctly.
        if name == "scheduler":
            raise AttributeError(name)
        return getattr(self.scheduler, name)

    @property
    def timesteps(self):
        """The wrapped scheduler's current ``timesteps``."""
        return self.scheduler.timesteps

    def set_timesteps(self, *args, **kwargs):
        """Forward to the wrapped scheduler's ``set_timesteps``.

        Callers such as pipelines pass extra arguments here (for example a
        ``device=`` keyword), so every positional and keyword argument is
        passed through unchanged, and the wrapped method's return value is
        returned.
        """
        return self.scheduler.set_timesteps(*args, **kwargs)

# Load the model and tokenizer.
# The sdxl-turbo Hub repository uses the diffusers pipeline layout: its
# tokenizer files live in the "tokenizer" subfolder (there is no tokenizer at
# the repo root), so AutoTokenizer must be pointed at that subfolder.
tokenizer = AutoTokenizer.from_pretrained("stabilityai/sdxl-turbo", subfolder="tokenizer")
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

# Replace the pipeline's scheduler with a SchedulerWrapper proxy around it.
scheduler = model.scheduler
wrapped_scheduler = SchedulerWrapper(scheduler)
model.scheduler = wrapped_scheduler

# All requests share one pipeline object, including its scheduler, which holds
# mutable per-run state (e.g. the ``timesteps`` set by ``set_timesteps``).
# Concurrent calls from worker threads would overwrite each other's state, so
# every pipeline invocation runs while holding this lock.
_PIPELINE_LOCK = threading.Lock()
_logger = logging.getLogger(__name__)


def _run_pipeline(prompt, num_inference_steps):
    """Invoke the shared ``model`` for one prompt while holding ``_PIPELINE_LOCK``."""
    with _PIPELINE_LOCK:
        return model(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=0.0,
            output_type="pil",
        )


async def generate_image(prompt, *, num_inference_steps=5):
    """Generate one image for ``prompt`` without blocking the event loop.

    The pipeline call runs in a worker thread via ``asyncio.to_thread``.
    Calls are serialized by ``_PIPELINE_LOCK`` because they share one model.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Denoising steps per image (default 5).

    Returns:
        The first generated image, or ``None`` if generation raised or the
        pipeline returned no images. Failures are logged rather than raised,
        so that one bad prompt does not abort a whole batch in ``gather``.
    """
    try:
        output = await asyncio.to_thread(_run_pipeline, prompt, num_inference_steps)
        images = output.images
    except Exception:
        # Best-effort per image: record the traceback and let the caller skip it.
        _logger.exception("Error generating image for prompt %r", prompt)
        return None
    if not images:
        _logger.error("No images returned by the model for prompt %r", prompt)
        return None
    return images[0]

async def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph of ``sentence_mapping``.

    Args:
        sentence_mapping: Mapping of paragraph key -> that paragraph's
            sentences, as received from the "Sentence Mapping" JSON input.
            A value may be a list of sentences or a single string. ``None``
            or an empty mapping yields no images.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Passed through unchanged to ``generate_prompt``.

    Returns:
        The successfully generated images, in paragraph order. Paragraphs
        for which ``generate_image`` returned ``None`` are omitted.
    """
    # An empty JSON field arrives as None; there is nothing to generate.
    if not sentence_mapping:
        return []

    prompts = []
    for sentences in sentence_mapping.values():
        # A bare string would otherwise be joined character by character.
        if isinstance(sentences, str):
            combined_sentence = sentences
        else:
            combined_sentence = " ".join(sentences)
        # NOTE(review): generate_prompt is not defined in the visible source;
        # confirm it exists at module scope.
        prompts.append(
            generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        )

    # gather preserves input order, so images line up with paragraphs.
    results = await asyncio.gather(*(generate_image(prompt) for prompt in prompts))
    return [image for image in results if image is not None]

# Art styles offered in the style dropdown.
_STYLE_CHOICES = ["oil painting", "sketch", "watercolor"]

# UI wiring: two JSON inputs plus a style picker feed ``inference``; the
# images it returns are shown in a gallery.
_interface_inputs = [
    gr.JSON(label="Sentence Mapping"),
    gr.JSON(label="Character Dict"),
    gr.Dropdown(_STYLE_CHOICES, label="Selected Style"),
]
gradio_interface = gr.Interface(
    fn=inference,
    inputs=_interface_inputs,
    outputs=gr.Gallery(label="Generated Images"),
)

def main():
    """Launch the Gradio app served by ``gradio_interface``."""
    gradio_interface.launch()


if __name__ == "__main__":
    main()