
Stable Diffusion XL Turbo


SDXL Turbo is an adversarial time-distilled Stable Diffusion XL (SDXL) model that can run inference in as little as a single step.

This guide shows how to use SDXL-Turbo for text-to-image and image-to-image generation.

Before you begin, make sure you have the following libraries installed:

# uncomment to install the necessary libraries in Colab
#!pip install -q diffusers transformers accelerate

Load model checkpoints

Model weights may be stored in separate subfolders on the Hub or locally, in which case you should use the [~StableDiffusionXLPipeline.from_pretrained] method:

from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipeline = pipeline.to("cuda")

You can also use the [~StableDiffusionXLPipeline.from_single_file] method to load a model checkpoint stored in a single-file format (.ckpt or .safetensors) from the Hub or locally:

from diffusers import StableDiffusionXLPipeline
import torch

pipeline = StableDiffusionXLPipeline.from_single_file(
    "https://huggingface.co/stabilityai/sdxl-turbo/blob/main/sd_xl_turbo_1.0_fp16.safetensors", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

Text-to-image

For text-to-image, pass a text prompt. By default, SDXL Turbo generates 512x512 images, and this resolution gives the best results. You can set the height and width parameters to 768x768 or 1024x1024, but expect some loss of quality at those sizes.

The model was trained without guidance_scale, so make sure to disable it by setting it to 0.0. A single inference step is enough to generate high-quality images, and increasing the number of steps to 2, 3 or 4 should improve image quality further.

from diffusers import AutoPipelineForText2Image
import torch

pipeline_text2image = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipeline_text2image = pipeline_text2image.to("cuda")

prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."

image = pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0]
image
(Image: generated image of a raccoon in a robe)
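As a minimal sketch, the settings recommended above (guidance disabled with `guidance_scale=0.0`, 1 to 4 inference steps, 512x512 by default) can be collected in one small helper. `sdxl_turbo_kwargs` is a hypothetical name, not a diffusers API, and the 1 to 4 step limit reflects this guide's recommendation rather than a check the pipeline itself performs.

```python
# Hypothetical helper (not part of diffusers): collects the call arguments
# this guide recommends for SDXL Turbo text-to-image generation.
import warnings


def sdxl_turbo_kwargs(num_inference_steps: int = 1, height: int = 512, width: int = 512) -> dict:
    """Return keyword arguments for an SDXL Turbo text-to-image call."""
    if not 1 <= num_inference_steps <= 4:
        # The guide recommends 1 step, optionally 2 to 4 for extra quality.
        raise ValueError("SDXL Turbo is meant to run with 1 to 4 inference steps")
    if (height, width) != (512, 512):
        # 512x512 gives the best results; other sizes may lose quality.
        warnings.warn(f"{height}x{width} may give lower quality than 512x512")
    return {
        "guidance_scale": 0.0,  # the model was trained without guidance
        "num_inference_steps": num_inference_steps,
        "height": height,
        "width": width,
    }
```

For example, `pipeline_text2image(prompt=prompt, **sdxl_turbo_kwargs(num_inference_steps=2)).images[0]` would request two steps at 512x512. Keeping these values in one place makes it harder to forget `guidance_scale=0.0` while experimenting.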

Image-to-image

For image-to-image generation, make sure that num_inference_steps * strength is greater than or equal to 1. The image-to-image pipeline runs for int(num_inference_steps * strength) steps, for example 0.5 * 2.0 = 1 step in the example below.

from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image, make_image_grid

# use from_pipe to avoid consuming additional memory when loading a checkpoint
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to("cuda")

init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
init_image = init_image.resize((512, 512))

prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"

image = pipeline(prompt, image=init_image, strength=0.5, guidance_scale=0.0, num_inference_steps=2).images[0]
make_image_grid([init_image, image], rows=1, cols=2)
(Image: image-to-image generation sample using SDXL Turbo)
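The step-count rule for image-to-image can be checked in plain Python before calling the pipeline. This is a sketch that mirrors the formula `int(num_inference_steps * strength)` as described in this guide; the helper names `effective_img2img_steps` and `min_steps_for` are hypothetical and not part of diffusers.

```python
# Sketch of the image-to-image step arithmetic described in this guide
# (hypothetical helpers, not code taken from diffusers).


def effective_img2img_steps(num_inference_steps: int, strength: float) -> int:
    """Number of denoising steps an image-to-image run actually performs."""
    if not 0.0 <= strength <= 1.0:
        raise ValueError("strength must be between 0 and 1")
    # int() truncates: 0.5 * 3 = 1.5 gives 1 step, and 0.4 * 2 = 0.8 gives 0.
    return int(num_inference_steps * strength)


def min_steps_for(strength: float) -> int:
    """Smallest num_inference_steps for which at least one step is run."""
    if not 0.0 < strength <= 1.0:
        raise ValueError("strength must be in (0, 1]")
    steps = 1
    # Count up instead of computing 1 / strength directly, because
    # floating-point rounding (e.g. 3 * 0.3 == 0.8999999999999999)
    # can make a closed-form answer off by one.
    while effective_img2img_steps(steps, strength) < 1:
        steps += 1
    return steps
```

With the values from the example above, `effective_img2img_steps(2, 0.5)` returns 1, while `effective_img2img_steps(1, 0.5)` returns 0, which is the case the rule warns about. `min_steps_for(0.5)` returns 2.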

Speed up SDXL Turbo even more

  • If you are using PyTorch version 2 or higher, compile the UNet. The first inference run will be very slow, but subsequent runs will be much faster.
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
  • When using the default VAE, keep it in float32 to avoid costly dtype conversions before and after each generation. You only need to do this once, before your first generation:
pipeline.upcast_vae()

Alternatively, you can use the 16-bit VAE created by community member @madebyollin, which does not need to be upcast to float32.
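To observe the `torch.compile` effect described above, it can help to time the first call separately from the later ones. The sketch below uses only the Python standard library; `torch_major_version` and `time_calls` are hypothetical helpers, and the version parsing assumes strings shaped like `torch.__version__` (for example `2.1.0+cu121`).

```python
# Generic benchmarking sketch (hypothetical helpers, not part of diffusers):
# time each call individually so the slow warm-up/compile call stands out.
import time
from typing import Callable, List


def torch_major_version(version: str) -> int:
    """Parse the major version from a string such as '2.1.0+cu121'."""
    return int(version.split(".")[0])


def time_calls(fn: Callable[[], object], runs: int = 3) -> List[float]:
    """Call fn() `runs` times and return each wall-clock duration in seconds."""
    durations = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations
```

For instance, after checking `torch_major_version(torch.__version__) >= 2` and compiling `pipeline.unet`, you would expect `time_calls(lambda: pipeline(prompt, image=init_image, strength=0.5, guidance_scale=0.0, num_inference_steps=2), runs=3)` to show a slow first entry (or first few, with `mode="reduce-overhead"`) followed by faster ones. Because the pipeline returns decoded PIL images by default, each measured call should include the GPU work.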