|
--- |
|
license: mit |
|
--- |
|
# A Text-Conditioned Diffusion-Prior |
|
|
|
## Training Details |
|
|
|
[Updated Reports Coming] |
|
|
|
## Source Code |
|
The models are diffusion prior trainers from https://github.com/lucidrains/DALLE2-pytorch
|
|
|
## Community: LAION |
|
Join us! https://discord.gg/uPMftTmrvS
|
|
|
--- |
|
|
|
## Intro |
|
|
|
A properly trained prior allows you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected in some way, then the ability to translate between them can be extremely helpful.
|
|
|
### Motivation |
|
|
|
Before we dive into the model, let’s look at a quick example of where the model may be helpful. |
|
|
|
For demonstration purposes, we will imagine that we wish to generate images from text using CLIP and a Decoder.
|
|
|
> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and its caption; however, there is no guarantee that these embeddings lie in the same space. While the generated embeddings are ***close***, the image and text embeddings occupy two disjoint sets.
|
|
|
```python
import clip

# Load models
clip_model, _ = clip.load("ViT-L/14")
decoder = Decoder(checkpoint="best.pth")  # a decoder trained on CLIP image embeddings

# Retrieve a prompt from the user and encode it with CLIP
prompt = "A corgi wearing sunglasses"
tokenized_text = clip.tokenize(prompt)
text_embedding = clip_model.encode_text(tokenized_text)

# Now, pass the text embedding to the decoder
predicted_image = decoder.sample(text_embedding)
```
|
|
|
> **Question**: *Can you spot the issue here?* |
|
> |
|
> **Answer**: *We’re trying to generate an image from a text embedding!* |
|
|
|
Unfortunately, we run into the issue mentioned previously: the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.
|
|
|
```python
import clip

# Load models
prior = Prior(checkpoint="prior.pth")        # a prior trained to go from: text -> CLIP text emb -> CLIP image emb
decoder = Decoder(checkpoint="decoder.pth")  # a decoder trained on CLIP image embeddings

# Retrieve a prompt from the user and encode it with the prior
prompt = "A corgi wearing sunglasses"
tokenized_text = clip.tokenize(prompt)
predicted_image_embedding = prior.sample(tokenized_text)  # <-- now we get an embedding in the same space as images!

# Now, pass the predicted image embedding to the decoder
predicted_image = decoder.sample(predicted_image_embedding)
```
|
|
|
With the prior we are able to successfully generate embeddings *within* CLIP's image space! As a result, the decoder performs much better, since it receives input that is much closer to its training data.
|
|
|
> **You may be asking yourself the following question:** |
|
> |
|
> *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"* |
|
> |
|
> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also...it's just an example :smile:
|
|
|
## Usage |
|
|
|
Utilizing a pre-trained prior is quite simple.
|
|
|
### Loading Checkpoints |
|
```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer


def load_diffusion_model(dprior_path, device):
    # the network configuration must match the one used at training time
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )

    # wrap the network in the diffusion prior, conditioning on CLIP text encodings
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    # the trainer manages EMA weights and checkpoint loading
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
```
|
|
|
Here we instantiate a model that matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*).
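
For example, loading onto a GPU might look like the following (the checkpoint path here is just a placeholder for wherever you saved the prior weights):

```python
import torch

# pick a device and load the trainer (returns a DiffusionPriorTrainer)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prior = load_diffusion_model("path/to/prior.pth", device)
```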
|
|
|
### Sampling |
|
Once we have a pre-trained model, generating embeddings is quite simple! |
|
```python
import clip

# `prior` is the trainer returned by load_diffusion_model() above

# tokenize the text
tokenized_text = clip.tokenize("<your amazing prompt>")

# predict an embedding
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
```
|
|
|
The tensor returned from `.sample()` has the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape `(1, 768)`.
|
|
|
> For CLIP priors, this is quite handy as it means you can use `prior.sample(tokenized_text)` as a drop-in replacement for `clip.encode_text()`.
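
As a rough illustration of that drop-in property, reusing `clip_model` from the motivation example, plus `tokenized_text` and the loaded `prior` from the snippets above (treat the exact names as placeholders):

```python
# both calls take the same tokenized text and return a (1, 768) tensor for ViT-L/14,
# but the prior's output lives in CLIP's *image* embedding space
text_embedding = clip_model.encode_text(tokenized_text)  # CLIP text space
image_embedding = prior.sample(tokenized_text)           # predicted CLIP image space
assert text_embedding.shape == image_embedding.shape     # e.g. torch.Size([1, 768])
```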
|
|
|
**Some things to note:** |
|
* It is possible to specify the number of candidate embeddings to sample (the default suggested by OpenAI is `n=2`). Put simply, the idea is to avoid getting unlucky with a single bad generation by sampling two embeddings and keeping the one with the higher cosine similarity to the prompt (a sketch of this idea follows this list).
|
* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0` but *ymmv*. |
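
Here is a rough sketch of the re-ranking idea from the first bullet. The prior's `sample()` already does this internally when more than one candidate is requested, so this is purely illustrative; it reuses the `clip_model` from the motivation example and the loaded `prior` from above.

```python
import clip
import torch
import torch.nn.functional as F

tokens = clip.tokenize("A corgi wearing sunglasses")

with torch.no_grad():
    # embed the prompt once with CLIP for scoring
    text_embedding = clip_model.encode_text(tokens).float()               # (1, 768)
    # draw two candidate image embeddings from the prior
    candidates = torch.cat([prior.sample(tokens) for _ in range(2)])      # (2, 768)

# keep the candidate that agrees most with the prompt
similarity = F.cosine_similarity(candidates.float(), text_embedding, dim=-1)  # (2,)
best_embedding = candidates[similarity.argmax()].unsqueeze(0)                 # (1, 768)
```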
|
|
|
--- |
|
|
|
## Training |
|
|
|
### Overview |
|
|
|
Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step required of you is preparing a dataset in the format that `EmbeddingReader` expects. Having pre-computed embeddings massively increases training efficiency, and is generally recommended since you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.
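
For reference, here is a rough sketch of streaming pre-computed embeddings with [embedding-reader](https://github.com/rom1504/embedding-reader) (which is what `EmbeddingReader` refers to). The folder path is a placeholder; consult the official training script for the exact dataloader setup.

```python
from embedding_reader import EmbeddingReader

# placeholder path: a folder of pre-computed .npy embedding files
reader = EmbeddingReader(embeddings_folder="image_embeddings/", file_format="npy")

print("embedding count:", reader.count)
print("embedding dimension:", reader.dimension)  # 768 for ViT-L/14

# stream the embeddings in batches
for embeddings, metadata in reader(batch_size=10_000, start=0, end=reader.count):
    print(embeddings.shape)  # (batch_size, 768)
    break
```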
|
|
|
### Dataset
|
|
|
To train the prior, it is highly recommended to use pre-computed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2dataset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) to generate the actual embeddings that can be used in the prior's dataloader.
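
As a rough sketch of the download step using img2dataset's Python API (the URL list and output folder are placeholders; afterwards, point clip-retrieval's inference tool at the downloaded dataset to produce the embedding files):

```python
from img2dataset import download

# placeholder inputs: a text file with one image URL per line, and an output folder
download(
    url_list="my_image_urls.txt",
    output_folder="images/",
    output_format="webdataset",  # shards that clip-retrieval can consume
    image_size=256,
    processes_count=16,
    thread_count=32,
)
```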
|
|
|
## Looking for more info?
|
|
|
This README continues in the official DALLE2-pytorch repo! You can find more details on training, metrics, and more [here](https://github.com/lucidrains/DALLE2-pytorch/blob/main/prior.md).