<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Wuerstchen

The [Wuerstchen](https://hf.co/papers/2306.00637) model drastically reduces computational costs by compressing the latent space by 42x, which accelerates inference without compromising image quality. During training, Wuerstchen uses two models (VQGAN + autoencoder) to compress the latents, and then a third model (a text-conditioned latent diffusion model) is conditioned on this highly compressed space to generate an image.
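
To get a feel for what a 42x spatial compression means, the sketch below computes the latent shape the prior would work on for a given image size. It assumes the image side is divided by a resolution multiple of 42.67 and rounded up, with 16 latent channels; these match our understanding of the diffusers Wuerstchen prior pipeline defaults but are assumptions here, and `prior_latent_shape` and `spatial_compression` are hypothetical helpers, not diffusers functions.

```py
import math


def prior_latent_shape(height, width, resolution_multiple=42.67, channels=16):
    """Hypothetical helper: (channels, latent_height, latent_width) for the prior.

    Assumes each image side is divided by `resolution_multiple` and rounded up,
    and that the latent has `channels` channels (both values are assumptions).
    """
    latent_height = math.ceil(height / resolution_multiple)
    latent_width = math.ceil(width / resolution_multiple)
    return (channels, latent_height, latent_width)


def spatial_compression(height, width, **kwargs):
    """Ratio between each image side length and the matching latent side length."""
    _, latent_height, latent_width = prior_latent_shape(height, width, **kwargs)
    return height / latent_height, width / latent_width


if __name__ == "__main__":
    # The training resolution used later in this guide, plus the inference size.
    for h, w in [(768, 768), (1024, 1024), (1536, 1024)]:
        print((h, w), "->", prior_latent_shape(h, w), spatial_compression(h, w))
```

Under these assumptions a 1024x1024 image maps to a 24x24 latent, so each latent position stands in for roughly a 42.67-pixel-wide patch.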
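
To fit the prior model into GPU memory and to speed up training, try enabling `gradient_accumulation_steps`, `gradient_checkpointing`, and `mixed_precision`. Gradient accumulation and gradient checkpointing lower memory usage, while mixed precision speeds up training and also reduces memory usage.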
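
Gradient accumulation averages the gradients of several small batches before making a single optimizer update, so the update matches what one larger batch would produce while only one small batch is in memory at a time. The sketch below shows this with a hypothetical one-parameter least-squares model in plain Python; it illustrates the idea behind `--gradient_accumulation_steps` and is not how 🤗 Accelerate implements it.

```py
def grad_mse(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_gradient(w, batch, accumulation_steps):
    """Split `batch` into equal micro-batches and average their gradients.

    Assumes len(batch) is divisible by accumulation_steps.
    """
    size = len(batch) // accumulation_steps
    total = 0.0
    for step in range(accumulation_steps):
        micro_batch = batch[step * size:(step + 1) * size]
        # Scale each micro-batch gradient so the running sum becomes an average.
        total += grad_mse(w, micro_batch) / accumulation_steps
    return total


def train(data, lr=0.005, accumulation_steps=4, num_updates=50):
    w = 0.0
    for _ in range(num_updates):
        # One optimizer update per full batch, built from several micro-batches.
        w -= lr * accumulated_gradient(w, data, accumulation_steps)
    return w


if __name__ == "__main__":
    data = [(float(x), 3.0 * x) for x in range(1, 17)]  # toy data following y = 3x
    print(accumulated_gradient(0.0, data, 4), grad_mse(0.0, data))
    print(train(data))
```

Because the micro-batches have equal size, the accumulated gradient equals the full-batch gradient up to floating-point rounding, which is why accumulation can replace a larger `--train_batch_size` when memory is tight.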

This guide explores the [train_text_to_image_prior.py](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) script to help you become more familiar with it and learn how you can adapt it for your own use case.

Before running the script, make sure you install the library from source:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:

```bash
cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
```

<Tip>

🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.

</Tip>

Initialize an 🤗 Accelerate environment:

```bash
accelerate config
```

To set up a default 🤗 Accelerate environment without choosing any configurations:

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell, like a notebook, you can use:

```py
from accelerate.utils import write_basic_config

write_basic_config()
```

Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.

<Tip>

The following sections highlight parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the [script](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) in detail. If you're interested in learning more, feel free to read through the script and let us know if you have any questions or concerns.

</Tip>

## Script parameters

The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L192) function. It provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.

For example, to speed up training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:

```bash
accelerate launch train_text_to_image_prior.py \
  --mixed_precision="fp16"
```
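
To see how a flag such as `--mixed_precision` reaches the script, here is a minimal, hypothetical `parse_args()` built with Python's `argparse`. It only mirrors the pattern of a default value that the command line can override; the `"no"`/`"fp16"`/`"bf16"` choices and the defaults shown are assumptions, so check the real function for the authoritative list.

```py
import argparse


def parse_args(input_args=None):
    """Hypothetical, trimmed-down stand-in for the script's parse_args()."""
    parser = argparse.ArgumentParser(description="Wuerstchen prior training (sketch)")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],  # assumed choices
        help="Mixed precision mode; None defers to the Accelerate config (assumption).",
    )
    parser.add_argument("--train_batch_size", type=int, default=1)  # assumed default
    parser.add_argument("--learning_rate", type=float, default=1e-4)  # assumed default
    # parse_args(None) reads sys.argv; passing a list makes the sketch testable.
    return parser.parse_args(input_args)


if __name__ == "__main__":
    args = parse_args(["--mixed_precision=fp16", "--learning_rate=1e-05"])
    print(args.mixed_precision, args.train_batch_size, args.learning_rate)
```

Any parameter you leave out of the launch command keeps its default, which is why the commands in this guide only list the values that differ.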

Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so let's dive right into the Wuerstchen training script!

## Training script

The training script is also similar to the [Text-to-image](text2image#training-script) training guide, but it's been modified to support Wuerstchen. This guide focuses on the code that is unique to the Wuerstchen training script.

The [`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L441) function starts by initializing the image encoder, an [EfficientNet](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py), in addition to the usual scheduler and tokenizer.

```py
import torch
from huggingface_hub import hf_hub_download
from modeling_efficient_net_encoder import EfficientNetEncoder  # lives next to the training script
from transformers.utils import ContextManagers

# deepspeed_zero_init_disabled_context_manager() is defined in train_text_to_image_prior.py
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
    pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
    state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
    image_encoder = EfficientNetEncoder()
    image_encoder.load_state_dict(state_dict["effnet_state_dict"])
    image_encoder.eval()
```

You'll also load the [`WuerstchenPrior`] model for optimization.

```py
from diffusers.pipelines.wuerstchen import WuerstchenPrior

prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")

# optimizer_cls is chosen earlier in the script (for example, torch.optim.AdamW)
optimizer = optimizer_cls(
    prior.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
```
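
The arguments passed to `optimizer_cls` are the usual AdamW hyperparameters. Assuming `optimizer_cls` resolves to `torch.optim.AdamW`, the sketch below applies the AdamW update rule to a single scalar parameter in plain Python, so you can see what `betas`, `eps`, and the decoupled `weight_decay` do. It illustrates the formula rather than PyTorch's implementation; `adamw_step` is a hypothetical name and its defaults are illustrative.

```py
import math


def adamw_step(param, grad, state, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """One AdamW update for a scalar parameter; `state` holds step, m, and v."""
    beta1, beta2 = betas
    state["step"] += 1
    # Decoupled weight decay: shrink the parameter directly, not through the gradient.
    param -= lr * weight_decay * param
    # Exponential moving averages of the gradient and of the squared gradient.
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    # Bias correction compensates for initializing m and v at zero.
    m_hat = state["m"] / (1 - beta1 ** state["step"])
    v_hat = state["v"] / (1 - beta2 ** state["step"])
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


if __name__ == "__main__":
    state = {"step": 0, "m": 0.0, "v": 0.0}
    p = adamw_step(1.0, 0.5, state, lr=0.1, weight_decay=0.0)
    # For gradients much larger than eps, the first step moves the parameter by about lr.
    print(p)
```

Because the step size is normalized by the running gradient magnitude, `--learning_rate` roughly bounds how far each weight moves per update, which is one reason small values such as `1e-05` are common when fine-tuning a pretrained prior.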

Next, you'll apply some [transforms](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L656) to the images and [tokenize](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L637) the captions:

```py
def preprocess_train(examples):
    # image_column, effnet_transforms, and tokenize_captions are defined elsewhere in the script
    images = [image.convert("RGB") for image in examples[image_column]]
    examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
    examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
    return examples
```
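
`tokenize_captions` returns both token IDs and an attention mask, which `preprocess_train` stores as `text_mask`. The pure-Python sketch below imitates that pattern with a hypothetical whitespace tokenizer: it picks one caption per example (randomly during training when a row has several, a convention we assume carries over from the other diffusers text-to-image scripts), truncates or pads to a fixed length, and marks real tokens with 1 and padding with 0. The real script uses the pipeline's CLIP tokenizer; `toy_tokenize`, `PAD_ID`, and `MAX_LENGTH` are illustrative names.

```py
import random

PAD_ID = 0
MAX_LENGTH = 8  # hypothetical; CLIP tokenizers use a much longer limit


def toy_tokenize(text, vocab):
    """Hypothetical tokenizer: map each lowercase word to an integer ID starting at 1."""
    return [vocab.setdefault(word, len(vocab) + 1) for word in text.lower().split()]


def tokenize_captions(examples, caption_column="text", vocab=None, is_train=True):
    """Return (input_ids, attention_masks), one fixed-length row per example."""
    vocab = {} if vocab is None else vocab
    all_ids, all_masks = [], []
    for caption in examples[caption_column]:
        if isinstance(caption, (list, tuple)):
            # Several captions for one image: sample one while training.
            caption = random.choice(caption) if is_train else caption[0]
        ids = toy_tokenize(caption, vocab)[:MAX_LENGTH]  # truncate
        mask = [1] * len(ids) + [0] * (MAX_LENGTH - len(ids))  # 1 = real token
        ids = ids + [PAD_ID] * (MAX_LENGTH - len(ids))  # pad
        all_ids.append(ids)
        all_masks.append(mask)
    return all_ids, all_masks


if __name__ == "__main__":
    batch = {"text": ["a man with a red scarf", ["a girl with blue hair", "a ninja"]]}
    input_ids, text_mask = tokenize_captions(batch)
    print(input_ids)
    print(text_mask)
```

The mask lets the text encoder ignore padding positions, so captions of different lengths can share one fixed-size batch.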

Finally, the [training loop](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L656) handles compressing the images to latent space with the `EfficientNetEncoder`, adding noise to the latents, and predicting the noise residual with the [`WuerstchenPrior`] model.

```py
# noisy_latents, timesteps, and prompt_embeds are computed earlier in the training loop
pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
```
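
To make the noising step concrete, the sketch below applies the standard forward-diffusion formula `noisy = sqrt(alpha_bar(t)) * latent + sqrt(1 - alpha_bar(t)) * noise` to a list of floats, then scores a noise prediction with mean squared error. It assumes a cosine schedule over continuous timesteps `t` in [0, 1] with offset `s = 0.008`; the script's scheduler may differ in details such as clamping, timestep sampling, or loss weighting, so treat this as an illustration of the objective, not a copy of the training loop. All function names are hypothetical.

```py
import math
import random


def cosine_alpha_bar(t, s=0.008):
    """Assumed cosine schedule: fraction of signal kept at continuous time t in [0, 1]."""
    def f(u):
        return math.cos((u + s) / (1 + s) * math.pi / 2) ** 2
    return f(t) / f(0.0)


def add_noise(latents, noise, t):
    """Forward diffusion: mix clean latents with Gaussian noise according to alpha_bar(t)."""
    a = cosine_alpha_bar(t)
    return [math.sqrt(a) * x + math.sqrt(1 - a) * n for x, n in zip(latents, noise)]


def mse(pred, target):
    """Mean squared error, the loss between predicted and true noise."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


if __name__ == "__main__":
    rng = random.Random(0)
    latents = [rng.gauss(0, 1) for _ in range(16)]  # stand-in for EfficientNet latents
    noise = [rng.gauss(0, 1) for _ in range(16)]
    t = rng.random()  # one continuous timestep for this sample
    noisy = add_noise(latents, noise, t)
    pred_noise = [0.0] * len(noise)  # a placeholder "model" that always predicts zero
    print(t, cosine_alpha_bar(t), mse(pred_noise, noise))
```

At `t = 0` the noisy latents equal the clean ones, and near `t = 1` they are almost pure noise; the prior learns to recover `noise` from `noisy` at every point in between.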

If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial, which breaks down the basic pattern of the denoising process.

## Launch the script

Once you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 🚀

Set the `DATASET_NAME` environment variable to the dataset name from the Hub. This guide uses the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset, but you can create and train on your own datasets as well (see the [Create a dataset for training](create_dataset) guide).

<Tip>

To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You'll also need to add `--validation_prompts` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.

</Tip>

```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch train_text_to_image_prior.py \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --dataloader_num_workers=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --validation_prompts="A robot naruto, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="wuerstchen-prior-naruto-model"
```
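
With `--train_batch_size=4` and `--gradient_accumulation_steps=4`, each optimizer update sees 4 × 4 = 16 images per process, and running on several GPUs multiplies that by the number of processes. The hypothetical helpers below do this arithmetic, assuming that `--max_train_steps` counts optimizer updates rather than micro-batches, which is how the diffusers example scripts usually count steps.

```py
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Images that contribute to a single optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def images_seen(max_train_steps, train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Total training images processed, assuming one step = one optimizer update."""
    return max_train_steps * effective_batch_size(
        train_batch_size, gradient_accumulation_steps, num_processes
    )


if __name__ == "__main__":
    # Values from the launch command above, on a single GPU and on 8 GPUs.
    for gpus in (1, 8):
        print(gpus, effective_batch_size(4, 4, gpus), images_seen(15000, 4, 4, gpus))
```

If you change the number of GPUs, you may want to adjust `--gradient_accumulation_steps` or `--learning_rate` so the effective batch size and learning rate stay in the range you intend.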

Once training is complete, you can use your newly trained model for inference!

```py
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

pipeline = AutoPipelineForText2Image.from_pretrained("path/to/saved/model", torch_dtype=torch.float16).to("cuda")

caption = "A cute bird naruto holding a shield"
images = pipeline(
    caption,
    width=1024,
    height=1536,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
    num_images_per_prompt=2,
).images
```

## Next steps

Congratulations on training a Wuerstchen model! To learn more about how to use your new model, the following may be helpful:

- Take a look at the [Wuerstchen](../api/pipelines/wuerstchen#text-to-image-generation) API documentation to learn more about how to use the pipeline for text-to-image generation and its limitations.