
Text-to-image

text-to-image νŒŒμΈνŠœλ‹ μŠ€ν¬λ¦½νŠΈλŠ” experimental μƒνƒœμž…λ‹ˆλ‹€. κ³Όμ ν•©ν•˜κΈ° 쉽고 치λͺ…적인 망각과 같은 λ¬Έμ œμ— λΆ€λ”ͺ히기 μ‰½μŠ΅λ‹ˆλ‹€. 자체 λ°μ΄ν„°μ…‹μ—μ„œ μ΅œμƒμ˜ κ²°κ³Όλ₯Ό μ–»μœΌλ €λ©΄ λ‹€μ–‘ν•œ ν•˜μ΄νΌνŒŒλΌλ―Έν„°λ₯Ό νƒμƒ‰ν•˜λŠ” 것이 μ’‹μŠ΅λ‹ˆλ‹€.

Text-to-image models like Stable Diffusion generate an image from a text prompt. This guide shows how to fine-tune the CompVis/stable-diffusion-v1-4 model on your own dataset with PyTorch and Flax. If you're interested in all the training scripts for text-to-image fine-tuning used in this guide, you can find them in this repository.

Before running the script, make sure you install the library's training dependencies:

pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt

And initialize an πŸ€— Accelerate environment:

accelerate config

If you have already cloned the repository, you don't need to perform this step. Instead, you can pass the path to your local checkout to the training script, and it will be loaded from there.

Hardware requirements

With gradient_checkpointing and mixed_precision, it should be possible to fine-tune the model on a single 24GB GPU. For a higher batch_size and faster training, it's better to use a GPU with more than 30GB of memory. You can also use JAX/Flax for fine-tuning on a TPU or GPU; see below for details.

You can reduce memory usage even further by enabling memory-efficient attention with xFormers. Make sure xFormers is installed and pass --enable_xformers_memory_efficient_attention to the training script.

xFormers is not available for Flax.

Hub에 λͺ¨λΈ μ—…λ‘œλ“œν•˜κΈ°

ν•™μŠ΅ μŠ€ν¬λ¦½νŠΈμ— λ‹€μŒ 인수λ₯Ό μΆ”κ°€ν•˜μ—¬ λͺ¨λΈμ„ ν—ˆλΈŒμ— μ €μž₯ν•©λ‹ˆλ‹€:

  --push_to_hub

Saving and loading checkpoints

ν•™μŠ΅ 쀑 λ°œμƒν•  수 μžˆλŠ” 일에 λŒ€λΉ„ν•˜μ—¬ μ •κΈ°μ μœΌλ‘œ 체크포인트λ₯Ό μ €μž₯ν•΄ λ‘λŠ” 것이 μ’‹μŠ΅λ‹ˆλ‹€. 체크포인트λ₯Ό μ €μž₯ν•˜λ €λ©΄ ν•™μŠ΅ μŠ€ν¬λ¦½νŠΈμ— λ‹€μŒ 인수λ₯Ό λͺ…μ‹œν•©λ‹ˆλ‹€.

  --checkpointing_steps=500

500μŠ€ν…λ§ˆλ‹€ 전체 ν•™μŠ΅ stateκ°€ 'output_dir'의 ν•˜μœ„ 폴더에 μ €μž₯λ©λ‹ˆλ‹€. μ²΄ν¬ν¬μΈνŠΈλŠ” 'checkpoint-'에 μ§€κΈˆκΉŒμ§€ ν•™μŠ΅λœ step μˆ˜μž…λ‹ˆλ‹€. 예λ₯Ό λ“€μ–΄ 'checkpoint-1500'은 1500 ν•™μŠ΅ step 후에 μ €μž₯된 μ²΄ν¬ν¬μΈνŠΈμž…λ‹ˆλ‹€.

ν•™μŠ΅μ„ μž¬κ°œν•˜κΈ° μœ„ν•΄ 체크포인트λ₯Ό 뢈러였렀면 '--resume_from_checkpoint' 인수λ₯Ό ν•™μŠ΅ μŠ€ν¬λ¦½νŠΈμ— λͺ…μ‹œν•˜κ³  μž¬κ°œν•  체크포인트λ₯Ό μ§€μ •ν•˜μ‹­μ‹œμ˜€. 예λ₯Ό λ“€μ–΄ λ‹€μŒ μΈμˆ˜λŠ” 1500개의 ν•™μŠ΅ step 후에 μ €μž₯된 μ²΄ν¬ν¬μΈνŠΈμ—μ„œλΆ€ν„° ν›ˆλ ¨μ„ μž¬κ°œν•©λ‹ˆλ‹€.

  --resume_from_checkpoint="checkpoint-1500"
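The training script also accepts --resume_from_checkpoint="latest", which picks the most recent checkpoint in output_dir automatically. A minimal sketch of that selection logic (the helper name latest_checkpoint is ours, not from the script):

```python
import re

def latest_checkpoint(entries):
    """Return the checkpoint folder with the highest step count, or None."""
    ckpts = [d for d in entries if re.fullmatch(r"checkpoint-\d+", d)]
    return max(ckpts, key=lambda d: int(d.split("-")[1])) if ckpts else None

print(latest_checkpoint(["checkpoint-500", "checkpoint-1500", "logs"]))  # checkpoint-1500
```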

νŒŒμΈνŠœλ‹

λ‹€μŒκ³Ό 같이 [Naruto BLIP μΊ‘μ…˜](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) λ°μ΄ν„°μ…‹μ—μ„œ νŒŒμΈνŠœλ‹ 싀행을 μœ„ν•΄ [PyTorch ν•™μŠ΅ 슀크립트](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)λ₯Ό μ‹€ν–‰ν•©λ‹ˆλ‹€:
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/naruto-blip-captions"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-naruto-model" 
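Note that the command above trades per-step batch size for gradient accumulation: the effective batch size is train_batch_size Γ— gradient_accumulation_steps Γ— number of processes. A quick check with the values used here (assuming a single-GPU launch):

```python
train_batch_size = 1             # per-GPU batch size from the command above
gradient_accumulation_steps = 4  # from the command above
num_processes = 1                # processes launched by accelerate (assumed: single GPU)

effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch_size)  # 4
```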

자체 λ°μ΄ν„°μ…‹μœΌλ‘œ νŒŒμΈνŠœλ‹ν•˜λ €λ©΄ πŸ€— Datasetsμ—μ„œ μš”κ΅¬ν•˜λŠ” ν˜•μ‹μ— 따라 데이터셋을 μ€€λΉ„ν•˜μ„Έμš”. 데이터셋을 ν—ˆλΈŒμ— μ—…λ‘œλ“œν•˜κ±°λ‚˜ [νŒŒμΌλ“€μ΄ μžˆλŠ” 둜컬 폴더λ₯Ό μ€€λΉ„](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

μ‚¬μš©μž μ»€μŠ€ν…€ loading logic을 μ‚¬μš©ν•˜λ €λ©΄ 슀크립트λ₯Ό μˆ˜μ •ν•˜μ‹­μ‹œμ˜€. 도움이 λ˜λ„λ‘ μ½”λ“œμ˜ μ μ ˆν•œ μœ„μΉ˜μ— 포인터λ₯Ό λ‚¨κ²ΌμŠ΅λ‹ˆλ‹€. πŸ€— μ•„λž˜ 예제 μŠ€ν¬λ¦½νŠΈλŠ” TRAIN_DIR의 둜컬 λ°μ΄ν„°μ…‹μœΌλ‘œλ₯Ό νŒŒμΈνŠœλ‹ν•˜λŠ” 방법과 OUTPUT_DIRμ—μ„œ λͺ¨λΈμ„ μ €μž₯ν•  μœ„μΉ˜λ₯Ό λ³΄μ—¬μ€λ‹ˆλ‹€:

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
export OUTPUT_DIR="path_to_save_model"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir=${OUTPUT_DIR}
Thanks to a contribution by [@duongna211](https://github.com/duongna21), you can train Stable Diffusion models faster on TPUs and GPUs with Flax. It is very efficient on TPU hardware, but works great on GPUs too. The Flax training script doesn't yet support features like gradient checkpointing or gradient accumulation, so you'll need a GPU with at least 30GB of memory or a TPU v3.

Before running the script, make sure you have the requirements installed:

pip install -U -r requirements_flax.txt

그러면 λ‹€μŒκ³Ό 같이 Flax ν•™μŠ΅ 슀크립트λ₯Ό μ‹€ν–‰ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export dataset_name="lambdalabs/naruto-blip-captions"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-naruto-model" 

자체 λ°μ΄ν„°μ…‹μœΌλ‘œ νŒŒμΈνŠœλ‹ν•˜λ €λ©΄ πŸ€— Datasetsμ—μ„œ μš”κ΅¬ν•˜λŠ” ν˜•μ‹μ— 따라 데이터셋을 μ€€λΉ„ν•˜μ„Έμš”. 데이터셋을 ν—ˆλΈŒμ— μ—…λ‘œλ“œν•˜κ±°λ‚˜ [νŒŒμΌλ“€μ΄ μžˆλŠ” 둜컬 폴더λ₯Ό μ€€λΉ„](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

μ‚¬μš©μž μ»€μŠ€ν…€ loading logic을 μ‚¬μš©ν•˜λ €λ©΄ 슀크립트λ₯Ό μˆ˜μ •ν•˜μ‹­μ‹œμ˜€. 도움이 λ˜λ„λ‘ μ½”λ“œμ˜ μ μ ˆν•œ μœ„μΉ˜μ— 포인터λ₯Ό λ‚¨κ²ΌμŠ΅λ‹ˆλ‹€. πŸ€— μ•„λž˜ 예제 μŠ€ν¬λ¦½νŠΈλŠ” TRAIN_DIR의 둜컬 λ°μ΄ν„°μ…‹μœΌλ‘œλ₯Ό νŒŒμΈνŠœλ‹ν•˜λŠ” 방법을 λ³΄μ—¬μ€λ‹ˆλ‹€:

export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-naruto-model"

LoRA

Text-to-image λͺ¨λΈ νŒŒμΈνŠœλ‹μ„ μœ„ν•΄, λŒ€κ·œλͺ¨ λͺ¨λΈ ν•™μŠ΅μ„ κ°€μ†ν™”ν•˜κΈ° μœ„ν•œ νŒŒμΈνŠœλ‹ 기술인 LoRA(Low-Rank Adaptation of Large Language Models)λ₯Ό μ‚¬μš©ν•  수 μžˆμŠ΅λ‹ˆλ‹€. μžμ„Έν•œ λ‚΄μš©μ€ LoRA ν•™μŠ΅ κ°€μ΄λ“œλ₯Ό μ°Έμ‘°ν•˜μ„Έμš”.

Inference

ν—ˆλΈŒμ˜ λͺ¨λΈ 경둜 λ˜λŠ” λͺ¨λΈ 이름을 [StableDiffusionPipeline]에 μ „λ‹¬ν•˜μ—¬ 좔둠을 μœ„ν•΄ 파인 νŠœλ‹λœ λͺ¨λΈμ„ 뢈러올 수 μžˆμŠ΅λ‹ˆλ‹€:

```python
import torch
from diffusers import StableDiffusionPipeline

model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-naruto.png")
```

Or with Flax:
```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

model_path = "path_to_saved_model"
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)

prompt = "yoda naruto"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
images[0].save("yoda-naruto.png")
```