|
<!--Copyright 2024 The HuggingFace Team. All rights reserved. |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
|
the License. You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
|
specific language governing permissions and limitations under the License. |
|
--> |
|
|
|
|
|
# Text-to-image |
|
|
|
<Tip warning={true}> |
|
|
|
The text-to-image fine-tuning script is experimental. It's easy to overfit and run into issues like catastrophic forgetting. We recommend exploring different hyperparameters to get the best results on your own dataset.
|
|
|
</Tip> |
|
|
|
Text-to-image models like Stable Diffusion generate an image from a text prompt. This guide shows you how to fine-tune the [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) model on your own dataset with PyTorch and Flax. If you're interested in taking a closer look, all the training scripts for text-to-image fine-tuning used in this guide can be found in this [repository](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image).
|
|
|
Before running the script, make sure to install the library and its training dependencies:
|
|
|
```bash |
|
pip install git+https://github.com/huggingface/diffusers.git |
|
pip install -U -r requirements.txt |
|
``` |
|
|
|
Then initialize a [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment:
|
|
|
```bash |
|
accelerate config |
|
``` |
|
|
|
If you have already cloned the repository, you don't need to go through these steps. Instead, you can pass the path to your local checkout to the training script and it will be loaded from there.
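For example, a minimal local setup might look like the following (a sketch, assuming you want to work from a clone of the repository and its `examples/text_to_image` folder rather than the `pip` install above):

```bash
# clone the repository and install diffusers from source
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install .

# install the training dependencies for the text-to-image example
cd examples/text_to_image
pip install -U -r requirements.txt
```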
|
|
|
### Hardware requirements
|
|
|
With `gradient_checkpointing` and `mixed_precision`, you should be able to fine-tune the model on a single 24GB GPU. For a higher `batch_size` and faster training, it's better to use a GPU with more than 30GB of memory. You can also use JAX/Flax for fine-tuning on TPUs or GPUs; see [below](#flax-jax-finetuning) for details.
|
|
|
You can reduce your memory footprint even further by enabling memory efficient attention with xFormers. Make sure you have [xFormers installed](./optimization/xformers) and pass the `--enable_xformers_memory_efficient_attention` flag to the training script.
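A minimal sketch of what that looks like (install xFormers, then add the flag to the full launch command shown in the fine-tuning section below; the truncated command here is only illustrative):

```bash
pip install xformers

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
  --dataset_name="lambdalabs/naruto-blip-captions" \
  --enable_xformers_memory_efficient_attention \
  --output_dir="sd-naruto-model"
```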
|
|
|
xFormers is not available for Flax.
|
|
|
## Upload the model to the Hub
|
|
|
Store your model on the Hub by adding the following argument to the training script:
|
|
|
```bash |
|
--push_to_hub |
|
``` |
|
|
|
|
|
## Saving and loading checkpoints
|
|
|
It's a good idea to save checkpoints regularly in case anything happens during training. To save a checkpoint, pass the following argument to the training script:
|
|
|
```bash |
|
--checkpointing_steps=500 |
|
``` |
|
|
|
Every 500 steps, the full training state is saved in a subfolder of `output_dir`. Checkpoints are named `checkpoint-` followed by the number of steps trained so far. For example, `checkpoint-1500` is a checkpoint saved after 1500 training steps.
|
|
|
To load a checkpoint and resume training, pass the `--resume_from_checkpoint` argument to the training script and specify the checkpoint you want to resume from. For example, the following argument resumes training from the checkpoint saved after 1500 training steps:
|
|
|
```bash |
|
--resume_from_checkpoint="checkpoint-1500" |
|
``` |
|
|
|
## Fine-tuning
|
|
|
<frameworkcontent> |
|
<pt> |
|
Launch the [PyTorch training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) like this to run a fine-tuning experiment on the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset:
|
|
|
|
|
```bash |
|
export MODEL_NAME="CompVis/stable-diffusion-v1-4" |
|
export dataset_name="lambdalabs/naruto-blip-captions" |
|
|
|
accelerate launch train_text_to_image.py \ |
|
--pretrained_model_name_or_path=$MODEL_NAME \ |
|
--dataset_name=$dataset_name \ |
|
--use_ema \ |
|
--resolution=512 --center_crop --random_flip \ |
|
--train_batch_size=1 \ |
|
--gradient_accumulation_steps=4 \ |
|
--gradient_checkpointing \ |
|
--mixed_precision="fp16" \ |
|
--max_train_steps=15000 \ |
|
--learning_rate=1e-05 \ |
|
--max_grad_norm=1 \ |
|
--lr_scheduler="constant" --lr_warmup_steps=0 \ |
|
--output_dir="sd-naruto-model" |
|
``` |
|
|
|
To fine-tune on your own dataset, prepare it in the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub) or [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
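For instance, a local image folder with captions could be laid out and loaded like this (a minimal sketch; the folder name, file names, and captions below are placeholders):

```python
# hypothetical layout of a local caption dataset:
# my_dataset/
# ├── metadata.jsonl   # one line per image, e.g. {"file_name": "0001.png", "text": "a caption"}
# ├── 0001.png
# └── 0002.png
from datasets import load_dataset

# "imagefolder" pairs each image with the caption stored in metadata.jsonl
dataset = load_dataset("imagefolder", data_dir="my_dataset", split="train")
print(dataset[0]["image"], dataset[0]["text"])
```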
|
|
|
Modify the script if you want to use custom loading logic; we've left pointers in the appropriate places in the code to help you. 🤗 The example script below shows how to fine-tune on a local dataset in `TRAIN_DIR` and where to save the model with `OUTPUT_DIR`:
|
|
|
|
|
```bash |
|
export MODEL_NAME="CompVis/stable-diffusion-v1-4" |
|
export TRAIN_DIR="path_to_your_dataset" |
|
export OUTPUT_DIR="path_to_save_model" |
|
|
|
accelerate launch train_text_to_image.py \ |
|
--pretrained_model_name_or_path=$MODEL_NAME \ |
|
--train_data_dir=$TRAIN_DIR \ |
|
--use_ema \ |
|
--resolution=512 --center_crop --random_flip \ |
|
--train_batch_size=1 \ |
|
--gradient_accumulation_steps=4 \ |
|
--gradient_checkpointing \ |
|
--mixed_precision="fp16" \ |
|
--max_train_steps=15000 \ |
|
--learning_rate=1e-05 \ |
|
--max_grad_norm=1 \ |
|
--lr_scheduler="constant" --lr_warmup_steps=0 \ |
|
--output_dir=${OUTPUT_DIR} |
|
``` |
|
|
|
</pt> |
|
<jax> |
|
Thanks to [@duongna211](https://github.com/duongna21)'s contribution, you can train Stable Diffusion models faster on TPUs and GPUs with Flax. This is very efficient on TPU hardware but works great on GPUs too. The Flax training script doesn't yet support features like gradient checkpointing or gradient accumulation, so you'll need a GPU with at least 30GB of memory or a TPU v3.
|
|
|
Before running the script, make sure the requirements are installed:
|
|
|
```bash |
|
pip install -U -r requirements_flax.txt |
|
``` |
|
|
|
Then you can launch the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) like this:
|
|
|
```bash |
|
export MODEL_NAME="runwayml/stable-diffusion-v1-5" |
|
export dataset_name="lambdalabs/naruto-blip-captions" |
|
|
|
python train_text_to_image_flax.py \ |
|
--pretrained_model_name_or_path=$MODEL_NAME \ |
|
--dataset_name=$dataset_name \ |
|
--resolution=512 --center_crop --random_flip \ |
|
--train_batch_size=1 \ |
|
--max_train_steps=15000 \ |
|
--learning_rate=1e-05 \ |
|
--max_grad_norm=1 \ |
|
--output_dir="sd-naruto-model" |
|
``` |
|
|
|
To fine-tune on your own dataset, prepare it in the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub) or [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
|
|
|
Modify the script if you want to use custom loading logic; we've left pointers in the appropriate places in the code to help you. 🤗 The example script below shows how to fine-tune on a local dataset in `TRAIN_DIR`:
|
|
|
```bash |
|
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" |
|
export TRAIN_DIR="path_to_your_dataset" |
|
|
|
python train_text_to_image_flax.py \ |
|
--pretrained_model_name_or_path=$MODEL_NAME \ |
|
--train_data_dir=$TRAIN_DIR \ |
|
--resolution=512 --center_crop --random_flip \ |
|
--train_batch_size=1 \ |
|
--mixed_precision="fp16" \ |
|
--max_train_steps=15000 \ |
|
--learning_rate=1e-05 \ |
|
--max_grad_norm=1 \ |
|
--output_dir="sd-naruto-model" |
|
``` |
|
</jax> |
|
</frameworkcontent> |
|
|
|
## LoRA |
|
|
|
To fine-tune a text-to-image model, you can also use LoRA (Low-Rank Adaptation of Large Language Models), a fine-tuning technique for accelerating the training of large models. See the [LoRA training](lora#text-to-image) guide for details.
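As a rough sketch (the exact arguments are documented in the LoRA guide and the example repository, so treat this invocation as illustrative rather than definitive), launching the LoRA variant of the script could look like this:

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/naruto-blip-captions"

accelerate launch train_text_to_image_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --max_train_steps=15000 \
  --learning_rate=1e-04 \
  --output_dir="sd-naruto-model-lora"
```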
|
|
|
## Inference
|
|
|
Load the fine-tuned model for inference by passing the model path, or its name on the Hub, to [`StableDiffusionPipeline`]:
|
|
|
<frameworkcontent> |
|
<pt> |
|
```python |
|
import torch

from diffusers import StableDiffusionPipeline
|
|
|
model_path = "path_to_saved_model" |
|
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) |
|
pipe.to("cuda") |
|
|
|
image = pipe(prompt="yoda").images[0] |
|
image.save("yoda-naruto.png") |
|
``` |
|
</pt> |
|
<jax> |
|
```python |
|
import jax |
|
import numpy as np |
|
from flax.jax_utils import replicate |
|
from flax.training.common_utils import shard |
|
from diffusers import FlaxStableDiffusionPipeline |
|
|
|
model_path = "path_to_saved_model" |
|
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
|
|
|
prompt = "yoda naruto" |
|
prng_seed = jax.random.PRNGKey(0) |
|
num_inference_steps = 50 |
|
|
|
num_samples = jax.device_count() |
|
prompt = num_samples * [prompt] |
|
prompt_ids = pipeline.prepare_inputs(prompt) |
|
|
|
# shard inputs and rng |
|
params = replicate(params) |
|
prng_seed = jax.random.split(prng_seed, jax.device_count()) |
|
prompt_ids = shard(prompt_ids) |
|
|
|
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images |
|
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) |
|
images[0].save("yoda-naruto.png")
|
``` |
|
</jax> |
|
</frameworkcontent> |