File size: 2,529 Bytes
c2c42ca
 
 
 
 
 
 
 
 
4ac12a1
c2c42ca
 
 
 
 
 
4ac12a1
3bd17ee
4ac12a1
e10dc6d
c2c42ca
4ac12a1
c2c42ca
 
 
 
 
4ac12a1
3bd17ee
4ac12a1
c2c42ca
4ac12a1
e10dc6d
143f063
 
34cb1b5
4ac12a1
c2c42ca
4ac12a1
 
c2c42ca
4ac12a1
 
c2c42ca
4ac12a1
c2c42ca
143f063
 
c2c42ca
 
 
 
 
 
 
 
 
 
34cb1b5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces

### SDXL Turbo ###
# Load the SDXL Turbo pipeline with fp16 weights (the "fp16" variant files).
# It is NOT moved to CUDA here (the .to("cuda") call is commented out);
# device placement is left to inference time.
pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
#pipe_turbo.to("cuda")

### SDXL Lightning ###
# SDXL-Lightning ships only a replacement UNet; every other component
# (text encoders, VAE, scheduler config) comes from the SDXL base repo.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors"  # 1-step, "x0" variant of the distilled UNet

# Build the UNet architecture from the base model's config (no pretrained
# weights), cast it to fp16, then overwrite its parameters with the
# downloaded Lightning checkpoint.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
# Assemble a full pipeline around the custom UNet (kept on CPU for now).
pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16")#.to("cuda")
# Drop the module-level name; pipe_lightning still holds the UNet.
del unet
# Swap in an Euler scheduler with "trailing" timestep spacing. The "_x0"
# checkpoint presumably predicts the denoised sample directly, hence
# prediction_type="sample" — NOTE(review): confirm against the model card.
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
#pipe_lightning.to("cuda")

### Hyper SDXL ###
# Hyper-SD likewise provides a replacement 1-step UNet for the SDXL base.
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

# Same pattern as above: base UNet architecture in fp16, weights replaced by
# the Hyper-SDXL checkpoint (`unet` is rebound to a fresh module here).
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name)))
pipe_hyper = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16")#.to("cuda")
# Hyper-SDXL is sampled with an LCM scheduler built from the base scheduler
# config; the caller supplies the explicit timestep(s) at inference time.
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
#pipe_hyper.to("cuda")
# Drop the module-level name; pipe_hyper still holds the UNet.
del unet

def _generate_on_gpu(pipe, prompt, **kwargs):
    """Generate one image with ``pipe`` temporarily moved to CUDA.

    The pipeline is moved to ``"cuda"`` for the call and moved back to
    ``"cpu"`` afterwards. The move back happens in a ``finally``, so a failing
    generation does not leave the model resident on the GPU. This lets the
    three pipelines take turns on the GPU instead of all being loaded at once.

    Args:
        pipe: A diffusers pipeline (here a ``StableDiffusionXLPipeline``).
        prompt: Text prompt forwarded to the pipeline.
        **kwargs: Extra pipeline call arguments (step count, guidance,
            explicit timesteps, ...).

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    pipe.to("cuda")
    try:
        return pipe(prompt=prompt, **kwargs).images[0]
    finally:
        pipe.to("cpu")


@spaces.GPU
def run_comparison(prompt):
    """Render ``prompt`` with SDXL Turbo, SDXL Lightning and Hyper-SDXL.

    Each model runs a single inference step without classifier-free guidance
    (``guidance_scale=0``). Pipelines are moved to the GPU one at a time.

    Args:
        prompt: Text prompt entered in the UI.

    Returns:
        Tuple ``(image_turbo, image_lightning, image_hyper)``, in the order of
        the Gradio output components.
    """
    # Move the pipelines, not the results: the images are only produced by
    # these calls and have no `.to` method.
    image_turbo = _generate_on_gpu(
        pipe_turbo, prompt, num_inference_steps=1, guidance_scale=0
    )
    image_lightning = _generate_on_gpu(
        pipe_lightning, prompt, num_inference_steps=1, guidance_scale=0
    )
    # The Hyper-SDXL 1-step pipeline (LCM scheduler) is driven by an explicit
    # single timestep.
    image_hyper = _generate_on_gpu(
        pipe_hyper,
        prompt,
        num_inference_steps=1,
        guidance_scale=0,
        timesteps=[800],
    )
    return image_turbo, image_lightning, image_hyper


# Narrow the page so the three result images sit side by side compactly.
css = '''
.gradio-container{max-width: 768px !important}
'''
# Single-prompt UI: one textbox, one button, and a row of three image outputs
# filled by run_comparison in (Turbo, Lightning, Hyper) order.
with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        # Pass the caption by keyword. gr.Image's first positional parameter
        # is `value`, so a positional string would be treated as the image
        # source rather than the label.
        image_hyper = gr.Image(label="Hyper SDXL")

    run.click(fn=run_comparison, inputs=prompt, outputs=[image_turbo, image_lightning, image_hyper])