import gradio as gr
import spaces
from diffusers import AutoPipelineForImage2Image, StableDiffusionInstructPix2PixPipeline
from loguru import logger
from PIL import Image

# This model is special-cased throughout the file: it uses a dedicated
# pipeline class, is called without a `strength` argument, and runs under the
# longer GPU duration.
PIX2PIX_MODEL = "timbrooks/instruct-pix2pix"

models = [
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    PIX2PIX_MODEL,
]
DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"


def load_pipeline(model):
    """Instantiate the diffusers pipeline for *model* via ``from_pretrained``.

    ``PIX2PIX_MODEL`` is loaded with ``StableDiffusionInstructPix2PixPipeline``;
    every other model id goes through ``AutoPipelineForImage2Image``.

    Args:
        model: Model id passed straight to ``from_pretrained``.

    Returns:
        The freshly constructed pipeline (not yet moved to any device).
    """
    pipeline_type = (
        StableDiffusionInstructPix2PixPipeline
        if model == PIX2PIX_MODEL
        else AutoPipelineForImage2Image
    )
    return pipeline_type.from_pretrained(model)


# Load the default model once at import time and move it to CUDA. The result
# is discarded, so this presumably only warms caches -- TODO confirm intent.
load_pipeline(DEFAULT_MODEL).to("cuda")
# Models whose pipeline has already been loaded at least once in this process.
loaded_models = {DEFAULT_MODEL}


def generate_image(
    model: str,
    prompt: str,
    init_image: Image.Image,
    strength: float,
    progress,
):
    """Run one image-to-image generation on CUDA and return the first image.

    Args:
        model: Model id to load with :func:`load_pipeline`.
        prompt: Text prompt forwarded to the pipeline.
        init_image: Source image forwarded to the pipeline as ``image``.
        strength: Forwarded as ``strength`` for every model except
            ``PIX2PIX_MODEL``, for which it is omitted.
        progress: Callable invoked after each step with a
            ``(completed_steps, total_steps)`` tuple.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    logger.debug(f"Loading pipeline: {dict(model=model)}")
    pipe = load_pipeline(model).to("cuda")

    logger.debug(f"Generating image: {dict(prompt=prompt)}")
    # The pix2pix pipeline is called without `strength`.
    additional_args = {} if model == PIX2PIX_MODEL else dict(strength=strength)

    def progress_callback(pipe, step_index, timestep, callback_kwargs):
        """Report step progress; hands ``callback_kwargs`` back unchanged."""
        logger.trace(
            f"Callback: {dict(num_timesteps=pipe.num_timesteps, step_index=step_index, timestep=timestep)}"
        )
        # step_index is zero-based, so +1 gives the number of finished steps.
        progress((step_index + 1, pipe.num_timesteps))
        return callback_kwargs

    images = pipe(
        prompt=prompt,
        image=init_image,
        callback_on_step_end=progress_callback,
        **additional_args,
    ).images
    return images[0]


@spaces.GPU
def gpu(*args, **kwargs):
    """Run :func:`generate_image` under ``spaces.GPU`` with default settings."""
    return generate_image(*args, **kwargs)


@spaces.GPU(duration=180)
def gpu_3min(*args, **kwargs):
    """Run :func:`generate_image` under ``spaces.GPU`` with ``duration=180``."""
    return generate_image(*args, **kwargs)


@logger.catch(reraise=True)
def generate(
    model: str,
    prompt: str,
    init_image: Image.Image,
    strength: float,
    # Default kept as in the original; presumably Gradio's convention for
    # having a progress tracker injected -- TODO confirm.
    progress=gr.Progress(),
):
    """Gradio entry point: validate input, prepare the model, then generate.

    Args:
        model: Model id chosen in the UI (custom values are allowed).
        prompt: Text prompt.
        init_image: Uploaded PIL image; downscaled in place to fit within
            1024x1024. ``None`` is rejected with ``gr.Error``.
        strength: Strength value from the slider.
        progress: Progress tracker forwarded to the GPU runner.

    Returns:
        The generated image.

    Raises:
        gr.Error: If no init image was supplied.
    """
    logger.info(
        f"Starting image generation: {dict(model=model, prompt=prompt, image=init_image, strength=strength)}"
    )

    # Without this guard a missing image fails below with an opaque
    # AttributeError on `None.thumbnail`.
    if init_image is None:
        raise gr.Error("Please provide an init image.")

    # Downscale the image
    init_image.thumbnail((1024, 1024))

    # Cache the model files for the pipeline
    if model not in loaded_models:
        logger.debug(f"Caching pipeline: {dict(model=model)}")
        load_pipeline(model)
        loaded_models.add(model)

    # The pix2pix model is routed to the runner with the longer duration.
    gpu_runner = gpu_3min if model == PIX2PIX_MODEL else gpu
    return gpu_runner(model, prompt, init_image, strength, progress)


demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Dropdown(
            label="Model", choices=models, value=DEFAULT_MODEL, allow_custom_value=True
        ),
        gr.Text(label="Prompt"),
        gr.Image(label="Init image", type="pil"),
        gr.Slider(label="Strength", minimum=0, maximum=1, value=0.3),
    ],
    outputs=[gr.Image(label="Output")],
)

demo.launch()