Stable Diffusion XL Turbo
[[open-in-colab]]
SDXL Turbo๋ adversarial time-distilled(์ ๋์ ์๊ฐ ์ ์ด) Stable Diffusion XL(SDXL) ๋ชจ๋ธ๋ก, ๋จ ํ ๋ฒ์ ์คํ ๋ง์ผ๋ก ์ถ๋ก ์ ์คํํ ์ ์์ต๋๋ค.
์ด ๊ฐ์ด๋์์๋ text-to-image์ image-to-image๋ฅผ ์ํ SDXL-Turbo๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค.
์์ํ๊ธฐ ์ ์ ๋ค์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์:
# Colab์์ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ค์นํ๊ธฐ ์ํด ์ฃผ์์ ์ ์ธํ์ธ์
#!pip install -q diffusers transformers accelerate
๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ ๋ถ๋ฌ์ค๊ธฐ
๋ชจ๋ธ ๊ฐ์ค์น๋ Hub์ ๋ณ๋ ํ์ ํด๋ ๋๋ ๋ก์ปฌ์ ์ ์ฅํ ์ ์์ผ๋ฉฐ, ์ด ๊ฒฝ์ฐ [~StableDiffusionXLPipeline.from_pretrained
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค:
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipeline = pipeline.to("cuda")
๋ํ [~StableDiffusionXLPipeline.from_single_file
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ธ ๋๋ ๋ก์ปฌ์์ ๋จ์ผ ํ์ผ ํ์(.ckpt
๋๋ .safetensors
)์ผ๋ก ์ ์ฅ๋ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ถ๋ฌ์ฌ ์๋ ์์ต๋๋ค:
from diffusers import StableDiffusionXLPipeline
import torch
pipeline = StableDiffusionXLPipeline.from_single_file(
"https://huggingface.co/stabilityai/sdxl-turbo/blob/main/sd_xl_turbo_1.0_fp16.safetensors", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")
Text-to-image
Text-to-image์ ๊ฒฝ์ฐ ํ
์คํธ ํ๋กฌํํธ๋ฅผ ์ ๋ฌํฉ๋๋ค. ๊ธฐ๋ณธ์ ์ผ๋ก SDXL Turbo๋ 512x512 ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ฉฐ, ์ด ํด์๋์์ ์ต์์ ๊ฒฐ๊ณผ๋ฅผ ์ ๊ณตํฉ๋๋ค. height
๋ฐ width
๋งค๊ฐ ๋ณ์๋ฅผ 768x768 ๋๋ 1024x1024๋ก ์ค์ ํ ์ ์์ง๋ง ์ด ๊ฒฝ์ฐ ํ์ง ์ ํ๋ฅผ ์์ํ ์ ์์ต๋๋ค.
๋ชจ๋ธ์ด guidance_scale
์์ด ํ์ต๋์์ผ๋ฏ๋ก ์ด๋ฅผ 0.0์ผ๋ก ์ค์ ํด ๋นํ์ฑํํด์ผ ํฉ๋๋ค. ๋จ์ผ ์ถ๋ก ์คํ
๋ง์ผ๋ก๋ ๊ณ ํ์ง ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์ ์์ต๋๋ค.
์คํ
์๋ฅผ 2, 3 ๋๋ 4๋ก ๋๋ฆฌ๋ฉด ์ด๋ฏธ์ง ํ์ง์ด ํฅ์๋ฉ๋๋ค.
from diffusers import AutoPipelineForText2Image
import torch
pipeline_text2image = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipeline_text2image = pipeline_text2image.to("cuda")
prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
image = pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0]
image

Image-to-image
Image-to-image ์์ฑ์ ๊ฒฝ์ฐ num_inference_steps * strength
๊ฐ 1๋ณด๋ค ํฌ๊ฑฐ๋ ๊ฐ์์ง ํ์ธํ์ธ์.
Image-to-image ํ์ดํ๋ผ์ธ์ ์๋ ์์ ์์ 0.5 * 2.0 = 1
์คํ
๊ณผ ๊ฐ์ด int(num_inference_steps * strength)
์คํ
์ผ๋ก ์คํ๋ฉ๋๋ค.
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image, make_image_grid
# ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ถ๋ฌ์ฌ ๋ ์ถ๊ฐ ๋ฉ๋ชจ๋ฆฌ ์๋ชจ๋ฅผ ํผํ๋ ค๋ฉด from_pipe๋ฅผ ์ฌ์ฉํ์ธ์.
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to("cuda")
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
init_image = init_image.resize((512, 512))
prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
image = pipeline(prompt, image=init_image, strength=0.5, guidance_scale=0.0, num_inference_steps=2).images[0]
make_image_grid([init_image, image], rows=1, cols=2)

SDXL Turbo ์๋ ํจ์ฌ ๋ ๋น ๋ฅด๊ฒ ํ๊ธฐ
- PyTorch ๋ฒ์ 2 ์ด์์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ UNet์ ์ปดํ์ผํฉ๋๋ค. ์ฒซ ๋ฒ์งธ ์ถ๋ก ์คํ์ ๋งค์ฐ ๋๋ฆฌ์ง๋ง ์ดํ ์คํ์ ํจ์ฌ ๋นจ๋ผ์ง๋๋ค.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
- ๊ธฐ๋ณธ VAE๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ, ๊ฐ ์์ฑ ์ ํ์ ๋น์ฉ์ด ๋ง์ด ๋๋
dtype
๋ณํ์ ํผํ๊ธฐ ์ํดfloat32
๋ก ์ ์งํ์ธ์. ์ด ์์ ์ ์ฒซ ์์ฑ ์ด์ ์ ํ ๋ฒ๋ง ์ํํ๋ฉด ๋ฉ๋๋ค:
pipe.upcast_vae()
๋๋, ์ปค๋ฎค๋ํฐ ํ์์ธ @madebyollin
์ด ๋ง๋ 16๋นํธ VAE๋ฅผ ์ฌ์ฉํ ์๋ ์์ผ๋ฉฐ, ์ด๋ float32
๋ก ์
์บ์คํธํ ํ์๊ฐ ์์ต๋๋ค.