"""Gradio demo for 6-step image generation with Scale-wise Distillation (SwD).

Loads Stable Diffusion 3.5 Medium, attaches the SwD PEFT adapter to its
transformer, and serves a single-prompt UI whose button calls ``infer``.
"""

import spaces
import gradio as gr
import numpy as np
import random
import functools
import torch
from diffusers import StableDiffusion3Pipeline
from inference import run
from peft import LoraConfig, get_peft_model, PeftModel

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# Distilled (SwD) adapter checkpoint, applied on top of the base transformer.
distill_check = 'yresearch/swd-medium-6-steps'
pipe.transformer = PeftModel.from_pretrained(
    pipe.transformer,
    distill_check,
)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


@spaces.GPU()
def infer(prompt, seed, randomize_seed):
    """Generate one image for *prompt* with the fixed 6-step SwD schedule.

    Args:
        prompt: Text prompt.
        seed: Seed for the torch generator; ignored when *randomize_seed*
            is true. May arrive from the Gradio slider as a float.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.

    Returns:
        A single image suitable for the ``gr.Image`` output component.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Generator.manual_seed only accepts ints; slider values may be floats.
    generator = torch.Generator().manual_seed(int(seed))

    # One sigma and one scale per inference step (6 of each).
    # NOTE(review): presumably `run` treats `scales` as per-step latent
    # sizes, growing from 32 to 128 — confirm against inference.run.
    sigmas = [1.0000, 0.9454, 0.8959, 0.7904, 0.7371, 0.6022]
    scales = [32, 48, 64, 80, 96, 128]
    images = run(
        pipe,
        prompt,
        sigmas=sigmas,
        scales=scales,
        num_inference_steps=6,
        guidance_scale=0.0,
        # Initial pixel size is derived from the first (smallest) scale.
        height=int(scales[0] * 8),
        width=int(scales[0] * 8),
        generator=generator,
    ).images
    # `gr.Image` renders one image, but diffusers-style outputs expose a list.
    return images[0] if isinstance(images, (list, tuple)) else images


examples = [
    "An astronaut riding a green horse",
    'Long-exposure night photography of a starry sky over a mountain range, with light trails.',
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A portrait of a girl with blonde, tousled hair, blue eyes",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

if torch.cuda.is_available():
    power_device = "GPU"
else:
    power_device = "CPU"

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            f"""
            # ⚡ Scale-wise Distillation ⚡
            # ⚡ Image Generation with 6-step SwD ⚡
            This is a demo of [Scale-wise Distillation](https://yandex-research.github.io/swd/), a diffusion distillation method proposed in [Scale-wise Distillation of Diffusion Models](https://arxiv.org/abs/2503.16397) by [Yandex Research](https://github.com/yandex-research).
            Currently running on {power_device}.
            """
        )
        # Implicit concatenation keeps the message on one logical line
        # (a plain "..." literal cannot span a physical newline).
        gr.Markdown(
            "If you enjoy the space, feel free to give a ⭐ to the Github Repo. "
            "[![GitHub Stars](https://img.shields.io/github/stars/yandex-research/swd?style=social)]"
            "(https://github.com/yandex-research/swd)"
        )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

        gr.Examples(
            examples=examples,
            inputs=[prompt],
            cache_examples=False,
        )

    run_button.click(
        fn=infer,
        inputs=[prompt, seed, randomize_seed],
        outputs=[result],
    )

demo.queue().launch(share=False)