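# Web-page residue captured together with this file (appears to be the
# uploader, commit label/hash, viewer links and file size). Kept as comments
# so the module remains valid Python:
#   cocktailpeanut's picture
#   localize
#   49901fa
#   raw
#   history blame
#   6.31 kB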
import gradio as gr
import torch
import devicetorch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LCMScheduler
from diffusers.schedulers import TCDScheduler
import spaces
from PIL import Image
# Maps each UI label to [weight-file template, default step count, guidance
# scale]. The "{}" left in every template is filled with the pipeline mode
# when the weights are loaded.
checkpoints = {
    **{
        f"{n_steps}-Step": [
            f"pcm_{{}}_smallcfg_{n_steps}step_converted.safetensors",
            n_steps,
            0.0,
        ]
        for n_steps in (2, 4, 8, 16)
    },
    **{
        f"Normal CFG {n_steps}-Step": [
            f"pcm_{{}}_normalcfg_{n_steps}step_converted.safetensors",
            n_steps,
            7.5,
        ]
        for n_steps in (4, 8, 16)
    },
    "LCM-Like LoRA": ["pcm_{}_lcmlike_lora_converted.safetensors", 4, 0.0],
}
# Key (checkpoint label + mode) of the LoRA most recently loaded by
# generate_image; None until the first generation request.
loaded = None

# Device handle returned by devicetorch for this torch installation.
device = devicetorch.get(torch)

# Both base pipelines are created in fp16 at import time and moved to `device`
# unconditionally. (An earlier variant, gated on torch.cuda.is_available() and
# hard-coded to "cuda", used to sit here as commented-out code.)
pipe_sdxl = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe_sdxl = pipe_sdxl.to(device)

pipe_sd15 = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe_sd15 = pipe_sd15.to(device)
@spaces.GPU(enable_queue=True)
def generate_image(
    prompt,
    ckpt,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
    mode="sdxl",
):
    """Generate one image with the PCM LoRA selected by ``ckpt``.

    Args:
        prompt: Text prompt passed to the pipeline.
        ckpt: Key into ``checkpoints``; selects the weight-file template and
            the guidance scale.
        num_inference_steps: Forwarded to the pipeline call as-is.
        progress: Gradio progress tracker (``track_tqdm=True``); not used
            directly in the body.
        mode: ``"sdxl"`` selects ``pipe_sdxl``; any other value selects
            ``pipe_sd15``. Also fills the weight-file template and is used as
            the ``subfolder`` when loading the weights.

    Returns:
        The first image of the pipeline result. When a module-level
        ``SAFETY_CHECKER`` flag is set and ``check_nsfw_images`` flags any
        image, a blank 512x512 RGB image is returned instead.

    Raises:
        KeyError: If ``ckpt`` is not a key of ``checkpoints``.
    """
    global loaded

    checkpoint = checkpoints[ckpt][0].format(mode)
    guidance_scale = checkpoints[ckpt][2]
    pipe = pipe_sdxl if mode == "sdxl" else pipe_sd15

    selection = ckpt + mode
    if loaded != selection:
        # Drop whatever LoRA this pipeline still carries from an earlier
        # request; otherwise adapters from different checkpoints accumulate
        # on the same pipeline.
        pipe.unload_lora_weights()
        pipe.load_lora_weights(
            "wangfuyun/PCM_Weights", weight_name=checkpoint, subfolder=mode
        )
        loaded = selection

        # The scheduler depends only on the checkpoint family, so it is
        # swapped together with the weights.
        if ckpt == "LCM-Like LoRA":
            pipe.scheduler = LCMScheduler()
        else:
            pipe.scheduler = TCDScheduler(
                num_train_timesteps=1000,
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                timestep_spacing="trailing",
            )

    results = pipe(
        prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
    )

    # SAFETY_CHECKER is an optional module-level switch that the visible file
    # never defines; reading it through globals() avoids a NameError after the
    # image has already been generated. If the switch is on,
    # check_nsfw_images is still resolved normally so a missing checker fails
    # loudly instead of silently skipping the check.
    if globals().get("SAFETY_CHECKER", False):
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            return Image.new("RGB", (512, 512))
        return images[0]
    return results.images[0]
def update_steps(ckpt):
    """Return a Gradio update for the step slider matching ``ckpt``.

    The slider value is reset to the checkpoint's default step count (second
    element of its ``checkpoints`` entry). The slider is made interactive
    only for the "LCM-Like LoRA" entry and locked for every other checkpoint.
    """
    default_steps = checkpoints[ckpt][1]
    user_adjustable = ckpt == "LCM-Like LoRA"
    return gr.update(interactive=user_adjustable, value=default_steps)
# Stylesheet passed to gr.Blocks: caps the width of the .gradio-container
# element at 60rem.
css = "\n".join(
    [
        "",
        ".gradio-container {",
        "max-width: 60rem !important;",
        "}",
        "",
    ]
)
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been stripped — confirm which components belong
# inside the Group/Row before relying on the exact layout.
with gr.Blocks(css=css) as demo:
    # Header: title, short description and project links.
    gr.Markdown(
        """
# Phased Consistency Model
Phased Consistency Model (PCM) is an image generation technique that addresses the limitations of the Latent Consistency Model (LCM) in high-resolution and text-conditioned image generation.
PCM outperforms LCM across various generation settings and achieves state-of-the-art results in both image and video generation.
[[paper](https://huggingface.co/papers/2405.18407)] [[arXiv](https://arxiv.org/abs/2405.18407)] [[code](https://github.com/G-U-N/Phased-Consistency-Model)] [[project page](https://g-u-n.github.io/projects/pcm)]
"""
    )

    # Input controls: prompt, checkpoint selector, step slider, run buttons.
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(scale=8, label="Prompt")
            ckpt = gr.Dropdown(
                choices=list(checkpoints),
                value="4-Step",
                label="Select inference steps",
            )
            steps = gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=4,
                interactive=False,
                label="Number of Inference Steps",
            )
            # Keep the slider in sync with the selected checkpoint; this
            # handler bypasses the queue and shows no progress indicator.
            ckpt.change(
                update_steps,
                inputs=[ckpt],
                outputs=[steps],
                show_progress=False,
                queue=False,
            )
            submit_sdxl = gr.Button("Run on SDXL", scale=1)
            submit_sd15 = gr.Button("Run on SD15", scale=1)

    img = gr.Image(label="PCM Image")

    # Each example row is [prompt, checkpoint label, step count].
    example_rows = [
        [" astronaut walking on the moon", "4-Step", 4],
        [
            "Photo of a dramatic cliffside lighthouse in a storm, waves crashing, symbol of guidance and resilience.",
            "8-Step",
            8,
        ],
        [
            "Vincent vangogh style, painting, a boy, clouds in the sky",
            "Normal CFG 4-Step",
            4,
        ],
        [
            "Echoes of a forgotten song drift across the moonlit sea, where a ghost ship sails, its spectral crew bound to an eternal quest for redemption.",
            "4-Step",
            4,
        ],
        [
            "Roger rabbit as a real person, photorealistic, cinematic.",
            "16-Step",
            16,
        ],
        [
            "tanding tall amidst the ruins, a stone golem awakens, vines and flowers sprouting from the crevices in its body.",
            "LCM-Like LoRA",
            4,
        ],
    ]
    gr.Examples(
        examples=example_rows,
        fn=generate_image,
        inputs=[prompt, ckpt, steps],
        outputs=[img],
        cache_examples="lazy",
    )

    # SDXL generation runs on checkpoint change, prompt submit, or the SDXL
    # button (generate_image defaults to mode="sdxl").
    gr.on(
        triggers=[ckpt.change, prompt.submit, submit_sdxl.click],
        fn=generate_image,
        inputs=[prompt, ckpt, steps],
        outputs=[img],
    )
    # SD 1.5 generation runs only from its dedicated button.
    gr.on(
        triggers=[submit_sd15.click],
        fn=lambda *args: generate_image(*args, mode="sd15"),
        inputs=[prompt, ckpt, steps],
        outputs=[img],
    )

demo.queue(api_open=False).launch(show_api=False)