import gradio as gr
import spaces
from diffusers import AutoPipelineForImage2Image, StableDiffusionInstructPix2PixPipeline
from loguru import logger
from PIL import Image

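# Image-to-image models selectable from the UI dropdown.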
models = [
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    "timbrooks/instruct-pix2pix",
]
DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"


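# instruct-pix2pix needs its dedicated pipeline class; every other model here
# can be loaded through AutoPipelineForImage2Image.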
def load_pipeline(model):
    pipeline_type = (
        StableDiffusionInstructPix2PixPipeline
        if model == "timbrooks/instruct-pix2pix"
        else AutoPipelineForImage2Image
    )

    return pipeline_type.from_pretrained(model)


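# Download and initialise the default pipeline at startup so the first request
# does not pay the model-download cost. The returned object is discarded; later
# calls reload it from the local cache.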
load_pipeline(DEFAULT_MODEL).to("cuda")
loaded_models = {DEFAULT_MODEL}


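# Runs inside the GPU allocation: reload the (already cached) pipeline, move it
# to CUDA and run a single image-to-image generation, reporting progress to Gradio.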
def generate_image(
    model: str,
    prompt: str,
    init_image: Image.Image,
    strength: float,
    progress,
):
    logger.debug(f"Loading pipeline: {dict(model=model)}")
    pipe = load_pipeline(model).to("cuda")

    logger.debug(f"Generating image: {dict(prompt=prompt)}")
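    # instruct-pix2pix does not take a strength argument, so only pass it to
    # the other pipelines.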
    additional_args = (
        {} if model == "timbrooks/instruct-pix2pix" else dict(strength=strength)
    )

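    # Called by diffusers after every denoising step; forwards the step count
    # to the Gradio progress bar.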
    def progress_callback(pipe, step_index, timestep, callback_kwargs):
        logger.trace(
            f"Callback: {dict(num_timesteps=pipe.num_timesteps, step_index=step_index, timestep=timestep)}"
        )
        progress((step_index + 1, pipe.num_timesteps))
        return callback_kwargs

    images = pipe(
        prompt=prompt,
        image=init_image,
        callback_on_step_end=progress_callback,
        **additional_args,
    ).images
    return images[0]


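# Thin wrappers so the same generation code can run under two spaces.GPU
# allocations: the default one, and a longer 180-second one for slower models.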
@spaces.GPU
def gpu(*args, **kwargs):
    return generate_image(*args, **kwargs)


@spaces.GPU(duration=180)
def gpu_3min(*args, **kwargs):
    return generate_image(*args, **kwargs)


@logger.catch(reraise=True)
def generate(
    model: str,
    prompt: str,
    init_image: Image.Image,
    strength: float,
    progress=gr.Progress(),
):
    logger.info(
        f"Starting image generation: {dict(model=model, prompt=prompt, image=init_image, strength=strength)}"
    )

    # Downscale the input image in place to at most 1024x1024, preserving aspect ratio
    init_image.thumbnail((1024, 1024))

    # Make sure the model files are downloaded and cached locally before
    # entering the GPU-allocated runner
    if model not in loaded_models:
        logger.debug(f"Caching pipeline: {dict(model=model)}")

        load_pipeline(model)
        loaded_models.add(model)

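    # Give instruct-pix2pix the longer GPU allocation; the other models use the default one.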
    gpu_runner = gpu_3min if model == "timbrooks/instruct-pix2pix" else gpu

    return gpu_runner(model, prompt, init_image, strength, progress)


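# Gradio UI: pick a model, describe the edit, upload an init image and choose
# how strongly it should be transformed.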
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Dropdown(
            label="Model", choices=models, value=DEFAULT_MODEL, allow_custom_value=True
        ),
        gr.Text(label="Prompt"),
        gr.Image(label="Init image", type="pil"),
        gr.Slider(label="Strength", minimum=0, maximum=1, value=0.3),
    ],
    outputs=[gr.Image(label="Output")],
)

demo.launch()