Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,529 Bytes
c2c42ca 4ac12a1 c2c42ca 4ac12a1 3bd17ee 4ac12a1 e10dc6d c2c42ca 4ac12a1 c2c42ca 4ac12a1 3bd17ee 4ac12a1 c2c42ca 4ac12a1 e10dc6d 143f063 34cb1b5 4ac12a1 c2c42ca 4ac12a1 c2c42ca 4ac12a1 c2c42ca 4ac12a1 c2c42ca 143f063 c2c42ca 34cb1b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
def _build_sdxl_pipeline(unet_repo, unet_file):
    """Return an fp16 SDXL pipeline on `base` using a replacement UNet.

    A UNet is instantiated from the base model's ``unet`` config, cast to
    float16, and filled with the safetensors checkpoint ``unet_file``
    downloaded from the Hub repository ``unet_repo``. The base pipeline is
    then loaded with that UNet swapped in.
    """
    swapped_unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
    weights_path = hf_hub_download(unet_repo, unet_file)
    swapped_unet.load_state_dict(load_file(weights_path))
    return StableDiffusionXLPipeline.from_pretrained(
        base,
        unet=swapped_unet,
        torch_dtype=torch.float16,
        variant="fp16",
    )


# NOTE: none of the pipelines below is moved to CUDA at import time
# (presumably because the GPU is only attached inside @spaces.GPU calls --
# confirm against the Space's hardware settings).

### SDXL Turbo ###
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16,
    variant="fp16",
)

### SDXL Lightning ###
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors"
pipe_lightning = _build_sdxl_pipeline(repo, ckpt)
# Replace the default scheduler with an Euler scheduler configured for
# trailing timestep spacing and "sample" prediction.
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(
    pipe_lightning.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample",
)

### Hyper SDXL ###
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
pipe_hyper = _build_sdxl_pipeline(repo_name, ckpt_name)
# The Hyper-SDXL pipeline is driven by an LCM scheduler built from the
# pipeline's existing scheduler config.
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
def _generate_on_gpu(pipe, prompt, **extra_call_kwargs):
    """Generate one image with `pipe`, holding it on CUDA only for the call.

    The pipeline is moved to the GPU, run for a single inference step with
    guidance disabled, and then moved back to the CPU. The move back happens
    in a ``finally`` so a failed generation does not leave the pipeline
    occupying GPU memory.

    Args:
        pipe: A callable diffusers pipeline exposing ``.to(device)``.
        prompt: Text prompt forwarded to the pipeline.
        **extra_call_kwargs: Additional keyword arguments for the pipeline
            call (e.g. ``timesteps``).

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    pipe.to("cuda")
    try:
        output = pipe(
            prompt=prompt,
            num_inference_steps=1,
            guidance_scale=0,
            **extra_call_kwargs,
        )
        return output.images[0]
    finally:
        pipe.to("cpu")


@spaces.GPU
def run_comparison(prompt):
    """Generate one single-step image per model for the same prompt.

    Args:
        prompt: Text prompt used for all three pipelines.

    Returns:
        A tuple ``(turbo, lightning, hyper)`` of generated images, in the
        order SDXL Turbo, SDXL Lightning, Hyper SDXL.
    """
    # Local names deliberately differ from the module-level Gradio components
    # (image_turbo, image_lightning, image_hyper) to avoid shadowing them.
    turbo_result = _generate_on_gpu(pipe_turbo, prompt)
    lightning_result = _generate_on_gpu(pipe_lightning, prompt)
    # The Hyper-SDXL call is given an explicit single timestep of 800.
    hyper_result = _generate_on_gpu(pipe_hyper, prompt, timesteps=[800])
    return turbo_result, lightning_result, hyper_result
# Custom CSS: limit the Gradio container to a maximum width of 768px.
css = '''
.gradio-container{max-width: 768px !important}
'''

# UI layout: a prompt box and a Run button above a row of three output images,
# one per model being compared.
with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        # The caption must be passed as `label=`; the first positional
        # parameter of gr.Image is `value`, not the label.
        image_hyper = gr.Image(label="Hyper SDXL")
    # Clicking Run sends the prompt to run_comparison and routes its three
    # returned images to the three components, in order.
    run.click(
        fn=run_comparison,
        inputs=prompt,
        outputs=[image_turbo, image_lightning, image_hyper],
    )
|