# image-to-image / app.py
# Last commit by dgoot: "Tune GPU duration based on model" (fbb30c1), 2.33 kB
import gradio as gr
import spaces
from diffusers import AutoPipelineForImage2Image, StableDiffusionInstructPix2PixPipeline
from loguru import logger
from PIL import Image
# Model id pre-selected in the "Model" dropdown.
DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"

# Model ids listed in the "Model" dropdown (the dropdown also accepts custom ids).
models = [
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-3-medium-diffusers",
    DEFAULT_MODEL,
    "timbrooks/instruct-pix2pix",
]
@spaces.GPU
def gpu(fn):
    """Call ``fn()`` inside a function decorated with ``spaces.GPU``.

    No ``duration`` is passed, so the decorator's default applies.
    ``generate`` uses this runner for every model except
    ``timbrooks/instruct-pix2pix``.

    Returns:
        Whatever ``fn`` returns.
    """
    result = fn()
    return result
@spaces.GPU(duration=180)
def gpu_3min(fn):
    """Call ``fn()`` inside a ``spaces.GPU`` function with ``duration=180``.

    The 180 matches the "3min" in the name, so it is presumably seconds.
    ``generate`` uses this longer-running variant for
    ``timbrooks/instruct-pix2pix``.

    Returns:
        Whatever ``fn`` returns.
    """
    result = fn()
    return result
@logger.catch(reraise=True)
def generate(
    model: str,
    prompt: str,
    init_image: Image.Image,
    strength: float,
    progress=gr.Progress(),
):
    """Run an image-to-image diffusion pipeline on ``init_image``.

    Args:
        model: Model id passed to ``from_pretrained``. The id
            ``"timbrooks/instruct-pix2pix"`` selects
            ``StableDiffusionInstructPix2PixPipeline`` and the 3-minute GPU
            runner. Every other id goes through ``AutoPipelineForImage2Image``
            and the default GPU runner.
        prompt: Text prompt forwarded to the pipeline.
        init_image: Source image. It is converted to RGB and downscaled to fit
            within 1024x1024 on a copy, so the caller's image is left
            untouched.
        strength: Forwarded as ``strength`` to non-pix2pix pipelines. It is
            not passed to the InstructPix2Pix pipeline.
        progress: Gradio progress tracker, updated after each pipeline step.

    Returns:
        The first image in the pipeline's ``.images`` output.

    Raises:
        gr.Error: If no input image was provided.
    """
    logger.info(
        f"Starting image generation: {dict(model=model, prompt=prompt, image=init_image, strength=strength)}"
    )

    if init_image is None:
        # Without this guard the call dies below with an opaque AttributeError.
        raise gr.Error("Please upload an init image.")

    # convert() always returns a new image, so the in-place thumbnail() below
    # no longer mutates the caller's object.
    init_image = init_image.convert("RGB")
    # Downscale the image (in place, aspect ratio preserved) to fit 1024x1024.
    init_image.thumbnail((1024, 1024))

    def progress_callback(pipe, step_index, timestep, callback_kwargs):
        # Invoked by the pipeline at the end of each step. It must return
        # callback_kwargs for the pipeline to continue.
        logger.trace(
            f"Callback: {dict(num_timesteps=pipe.num_timesteps, step_index=step_index, timestep=timestep)}"
        )
        progress((step_index + 1, pipe.num_timesteps))
        return callback_kwargs

    is_pix2pix = model == "timbrooks/instruct-pix2pix"

    pipeline_type = (
        StableDiffusionInstructPix2PixPipeline
        if is_pix2pix
        else AutoPipelineForImage2Image
    )
    logger.debug(f"Loading pipeline: {dict(model=model)}")
    pipe = pipeline_type.from_pretrained(model).to("cuda")

    logger.debug(f"Generating image: {dict(prompt=prompt)}")
    # Only the non-pix2pix pipelines receive the strength argument.
    additional_args = {} if is_pix2pix else {"strength": strength}
    # InstructPix2Pix gets the longer (duration=180) GPU allocation.
    gpu_runner = gpu_3min if is_pix2pix else gpu
    images = gpu_runner(
        lambda: pipe(
            prompt=prompt,
            image=init_image,
            callback_on_step_end=progress_callback,
            **additional_args,
        ).images
    )
    return images[0]
# Web UI: the input components map positionally onto generate's parameters
# (model, prompt, init_image, strength). The single output shows the image it
# returns.
demo = gr.Interface(
    fn=generate,
    inputs=[
        # Choices come from `models`; users may also type an arbitrary id.
        gr.Dropdown(
            allow_custom_value=True,
            choices=models,
            value=DEFAULT_MODEL,
            label="Model",
        ),
        gr.Text(label="Prompt"),
        # type="pil" makes Gradio pass a PIL image to generate.
        gr.Image(type="pil", label="Init image"),
        gr.Slider(value=0.3, minimum=0, maximum=1, label="Strength"),
    ],
    outputs=[gr.Image(label="Output")],
)

demo.launch()