File size: 2,217 Bytes
f6b8b7e
b85438c
f6b8b7e
c1282a1
216a041
c1282a1
b5ad13a
9b06634
f6b8b7e
9b06634
b85438c
6d1d03a
216a041
f466dd9
9b06634
216a041
 
 
9b06634
e15ee10
 
d26a101
 
f6b8b7e
c301a62
6449f8f
f466dd9
b5ad13a
c1282a1
e15ee10
b85438c
d26a101
216a041
6035350
1d0b035
e0ec116
 
f6b8b7e
e0ec116
 
6449f8f
6035350
e0ec116
b5ad13a
216a041
 
 
e0ec116
216a041
 
b5ad13a
6035350
f466dd9
 
 
d05fa5e
1adc78a
 
e0ec116
d05fa5e
6035350
58f74fc
f466dd9
 
a9b8939
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import asyncio
import traceback

import gradio as gr
from diffusers import (
    AutoPipelineForText2Image,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
)
from PIL import Image

from generate_propmts import generate_prompt

# Load the text-to-image pipeline once, at import time, so every request reuses it.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
# Swap the scheduler the pipeline was loaded with for a DPM-Solver multistep one built from the same config; DPMSolverMultistepScheduler must be imported from diffusers.
model.scheduler = DPMSolverMultistepScheduler.from_config(model.scheduler.config)


async def generate_image(prompt, num_inference_steps=5):
    """Generate a single image for *prompt* using the module-level pipeline.

    The blocking pipeline call is handed to a worker thread through
    ``asyncio.to_thread`` so that it can be awaited from async code.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        num_inference_steps: Value forwarded as the pipeline's
            ``num_inference_steps`` argument. Defaults to 5.

    Returns:
        The first image of the pipeline output, or ``None`` when the call
        fails or yields no images. Errors are printed, never raised.
    """
    try:
        output = await asyncio.to_thread(
            model,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=0.0,
            output_type="pil",
        )
        if not output.images:
            raise RuntimeError("No images returned by the model.")
        return output.images[0]
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and signal it
        # with None so one bad prompt does not abort the other generations.
        print(f"Error generating image: {e}")
        traceback.print_exc()
        return None
 

async def inference(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph and generate an image for each.

    Args:
        sentence_mapping: Mapping of paragraph number to a sequence of
            sentence strings; each paragraph's sentences are joined with
            single spaces before being turned into a prompt. May be empty
            or ``None``.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Passed through unchanged to ``generate_prompt``.

    Returns:
        A list of the images that were generated successfully, in paragraph
        order. Generations that returned ``None`` are dropped. An empty or
        missing ``sentence_mapping`` yields an empty list.
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')

    # Nothing to illustrate (e.g. the JSON input was left empty): return
    # early instead of failing on ``None.items()``.
    if not sentence_mapping:
        return []

    # Generate prompts for each paragraph
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append(prompt)
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # Run all generations concurrently; gather returns results in task order.
    # NOTE(review): every task calls the same module-level pipeline from its
    # own worker thread - confirm the pipeline tolerates concurrent calls.
    results = await asyncio.gather(*(generate_image(prompt) for prompt in prompts))

    # Failed generations come back as None; keep only real images.
    return [image for image in results if image is not None]

# Web UI definition: two JSON inputs and a style dropdown are passed to
# `inference`, whose returned images are displayed in a gallery.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(
            choices=["oil painting", "sketch", "watercolor"],
            label="Selected Style",
        ),
    ],
    outputs=gr.Gallery(label="Generated Images"),
)

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    gradio_interface.launch()