Spaces:
Runtime error
Text-to-image
text-to-image νμΈνλ μ€ν¬λ¦½νΈλ experimental μνμ λλ€. κ³Όμ ν©νκΈ° μ½κ³ μΉλͺ μ μΈ λ§κ°κ³Ό κ°μ λ¬Έμ μ λΆλͺνκΈ° μ½μ΅λλ€. μ체 λ°μ΄ν°μ μμ μ΅μμ κ²°κ³Όλ₯Ό μ»μΌλ €λ©΄ λ€μν νμ΄νΌνλΌλ―Έν°λ₯Ό νμνλ κ²μ΄ μ’μ΅λλ€.
Stable Diffusionκ³Ό κ°μ text-to-image λͺ¨λΈμ ν
μ€νΈ ν둬ννΈμμ μ΄λ―Έμ§λ₯Ό μμ±ν©λλ€. μ΄ κ°μ΄λλ PyTorch λ° Flaxλ₯Ό μ¬μ©νμ¬ μ체 λ°μ΄ν°μ
μμ CompVis/stable-diffusion-v1-4
λͺ¨λΈλ‘ νμΈνλνλ λ°©λ²μ 보μ¬μ€λλ€. μ΄ κ°μ΄λμ μ¬μ©λ text-to-image νμΈνλμ μν λͺ¨λ νμ΅ μ€ν¬λ¦½νΈμ κ΄μ¬μ΄ μλ κ²½μ° μ΄ λ¦¬ν¬μ§ν 리μμ μμΈν μ°Ύμ μ μμ΅λλ€.
μ€ν¬λ¦½νΈλ₯Ό μ€ννκΈ° μ μ, λΌμ΄λΈλ¬λ¦¬μ νμ΅ dependencyλ€μ μ€μΉν΄μΌ ν©λλ€:
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
κ·Έλ¦¬κ³ π€Accelerate νκ²½μ μ΄κΈ°νν©λλ€:
accelerate config
리ν¬μ§ν 리λ₯Ό μ΄λ―Έ 볡μ ν κ²½μ°, μ΄ λ¨κ³λ₯Ό μνν νμκ° μμ΅λλ€. λμ , λ‘컬 체ν¬μμ κ²½λ‘λ₯Ό νμ΅ μ€ν¬λ¦½νΈμ λͺ μν μ μμΌλ©° κ±°κΈ°μμ λ‘λλ©λλ€.
νλμ¨μ΄ μꡬ μ¬ν
gradient_checkpointing
λ° mixed_precision
μ μ¬μ©νλ©΄ λ¨μΌ 24GB GPUμμ λͺ¨λΈμ νμΈνλν μ μμ΅λλ€. λ λμ batch_size
μ λ λΉ λ₯Έ νλ ¨μ μν΄μλ GPU λ©λͺ¨λ¦¬κ° 30GB μ΄μμΈ GPUλ₯Ό μ¬μ©νλ κ²μ΄ μ’μ΅λλ€. TPU λλ GPUμμ νμΈνλμ μν΄ JAXλ Flaxλ₯Ό μ¬μ©ν μλ μμ΅λλ€. μμΈν λ΄μ©μ μλλ₯Ό μ°Έμ‘°νμΈμ.
xFormersλ‘ memory efficient attentionμ νμ±ννμ¬ λ©λͺ¨λ¦¬ μ¬μ©λ ν¨μ¬ λ μ€μΌ μ μμ΅λλ€. xFormersκ° μ€μΉλμ΄ μλμ§ νμΈνκ³ --enable_xformers_memory_efficient_attention
λ₯Ό νμ΅ μ€ν¬λ¦½νΈμ λͺ
μν©λλ€.
xFormersλ Flaxμ μ¬μ©ν μ μμ΅λλ€.
Hubμ λͺ¨λΈ μ λ‘λνκΈ°
νμ΅ μ€ν¬λ¦½νΈμ λ€μ μΈμλ₯Ό μΆκ°νμ¬ λͺ¨λΈμ νλΈμ μ μ₯ν©λλ€:
--push_to_hub
체ν¬ν¬μΈνΈ μ μ₯ λ° λΆλ¬μ€κΈ°
νμ΅ μ€ λ°μν μ μλ μΌμ λλΉνμ¬ μ κΈ°μ μΌλ‘ 체ν¬ν¬μΈνΈλ₯Ό μ μ₯ν΄ λλ κ²μ΄ μ’μ΅λλ€. 체ν¬ν¬μΈνΈλ₯Ό μ μ₯νλ €λ©΄ νμ΅ μ€ν¬λ¦½νΈμ λ€μ μΈμλ₯Ό λͺ μν©λλ€.
--checkpointing_steps=500
500μ€ν λ§λ€ μ 체 νμ΅ stateκ° 'output_dir'μ νμ ν΄λμ μ μ₯λ©λλ€. 체ν¬ν¬μΈνΈλ 'checkpoint-'μ μ§κΈκΉμ§ νμ΅λ step μμ λλ€. μλ₯Ό λ€μ΄ 'checkpoint-1500'μ 1500 νμ΅ step νμ μ μ₯λ 체ν¬ν¬μΈνΈμ λλ€.
νμ΅μ μ¬κ°νκΈ° μν΄ μ²΄ν¬ν¬μΈνΈλ₯Ό λΆλ¬μ€λ €λ©΄ '--resume_from_checkpoint' μΈμλ₯Ό νμ΅ μ€ν¬λ¦½νΈμ λͺ μνκ³ μ¬κ°ν 체ν¬ν¬μΈνΈλ₯Ό μ§μ νμμμ€. μλ₯Ό λ€μ΄ λ€μ μΈμλ 1500κ°μ νμ΅ step νμ μ μ₯λ 체ν¬ν¬μΈνΈμμλΆν° νλ ¨μ μ¬κ°ν©λλ€.
--resume_from_checkpoint="checkpoint-1500"
νμΈνλ
λ€μκ³Ό κ°μ΄ [Naruto BLIP μΊ‘μ ](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) λ°μ΄ν°μ μμ νμΈνλ μ€νμ μν΄ [PyTorch νμ΅ μ€ν¬λ¦½νΈ](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)λ₯Ό μ€νν©λλ€:export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/naruto-blip-captions"
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-naruto-model"
μ체 λ°μ΄ν°μ μΌλ‘ νμΈνλνλ €λ©΄ π€ Datasetsμμ μꡬνλ νμμ λ°λΌ λ°μ΄ν°μ μ μ€λΉνμΈμ. λ°μ΄ν°μ μ νλΈμ μ λ‘λνκ±°λ [νμΌλ€μ΄ μλ λ‘컬 ν΄λλ₯Ό μ€λΉ](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ν μ μμ΅λλ€.
μ¬μ©μ 컀μ€ν
loading logicμ μ¬μ©νλ €λ©΄ μ€ν¬λ¦½νΈλ₯Ό μμ νμμμ€. λμμ΄ λλλ‘ μ½λμ μ μ ν μμΉμ ν¬μΈν°λ₯Ό λ¨κ²Όμ΅λλ€. π€ μλ μμ μ€ν¬λ¦½νΈλ TRAIN_DIR
μ λ‘컬 λ°μ΄ν°μ
μΌλ‘λ₯Ό νμΈνλνλ λ°©λ²κ³Ό OUTPUT_DIR
μμ λͺ¨λΈμ μ μ₯ν μμΉλ₯Ό 보μ¬μ€λλ€:
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
export OUTPUT_DIR="path_to_save_model"
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir=${OUTPUT_DIR}
[@duongna211](https://github.com/duongna21)μ κΈ°μ¬λ‘, Flaxλ₯Ό μ¬μ©ν΄ TPU λ° GPUμμ Stable Diffusion λͺ¨λΈμ λ λΉ λ₯΄κ² νμ΅ν μ μμ΅λλ€. μ΄λ TPU νλμ¨μ΄μμ λ§€μ° ν¨μ¨μ μ΄μ§λ§ GPUμμλ νλ₯νκ² μλν©λλ€. Flax νμ΅ μ€ν¬λ¦½νΈλ gradient checkpointingλ gradient accumulationκ³Ό κ°μ κΈ°λ₯μ μμ§ μ§μνμ§ μμΌλ―λ‘ λ©λͺ¨λ¦¬κ° 30GB μ΄μμΈ GPU λλ TPU v3κ° νμν©λλ€.
μ€ν¬λ¦½νΈλ₯Ό μ€ννκΈ° μ μ μꡬ μ¬νμ΄ μ€μΉλμ΄ μλμ§ νμΈνμμμ€:
pip install -U -r requirements_flax.txt
κ·Έλ¬λ©΄ λ€μκ³Ό κ°μ΄ Flax νμ΅ μ€ν¬λ¦½νΈλ₯Ό μ€νν μ μμ΅λλ€.
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export dataset_name="lambdalabs/naruto-blip-captions"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-naruto-model"
μ체 λ°μ΄ν°μ μΌλ‘ νμΈνλνλ €λ©΄ π€ Datasetsμμ μꡬνλ νμμ λ°λΌ λ°μ΄ν°μ μ μ€λΉνμΈμ. λ°μ΄ν°μ μ νλΈμ μ λ‘λνκ±°λ [νμΌλ€μ΄ μλ λ‘컬 ν΄λλ₯Ό μ€λΉ](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ν μ μμ΅λλ€.
μ¬μ©μ 컀μ€ν
loading logicμ μ¬μ©νλ €λ©΄ μ€ν¬λ¦½νΈλ₯Ό μμ νμμμ€. λμμ΄ λλλ‘ μ½λμ μ μ ν μμΉμ ν¬μΈν°λ₯Ό λ¨κ²Όμ΅λλ€. π€ μλ μμ μ€ν¬λ¦½νΈλ TRAIN_DIR
μ λ‘컬 λ°μ΄ν°μ
μΌλ‘λ₯Ό νμΈνλνλ λ°©λ²μ 보μ¬μ€λλ€:
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-naruto-model"
LoRA
Text-to-image λͺ¨λΈ νμΈνλμ μν΄, λκ·λͺ¨ λͺ¨λΈ νμ΅μ κ°μννκΈ° μν νμΈνλ κΈ°μ μΈ LoRA(Low-Rank Adaptation of Large Language Models)λ₯Ό μ¬μ©ν μ μμ΅λλ€. μμΈν λ΄μ©μ LoRA νμ΅ κ°μ΄λλ₯Ό μ°Έμ‘°νμΈμ.
μΆλ‘
νλΈμ λͺ¨λΈ κ²½λ‘ λλ λͺ¨λΈ μ΄λ¦μ [StableDiffusionPipeline
]μ μ λ¬νμ¬ μΆλ‘ μ μν΄ νμΈ νλλ λͺ¨λΈμ λΆλ¬μ¬ μ μμ΅λλ€:
model_path = "path_to_saved_model" pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) pipe.to("cuda")
image = pipe(prompt="yoda").images[0] image.save("yoda-naruto.png")
</pt>
<jax>
```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
model_path = "path_to_saved_model"
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
prompt = "yoda naruto"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
image.save("yoda-naruto.png")