## CLIPA

In this work, we present a surprising finding that there exists an _inverse_ scaling law for CLIP training: the larger the image/text encoders used, the shorter the sequence length of image/text tokens that can be applied during training.

Moreover, we show that the strategy for reducing image/text token length plays a crucial role in determining the quality of this scaling law.

As a result of this finding, we are able to successfully train CLIP even with academic resources.
For example, on an A100 eight-GPU server, our CLIP models achieve zero-shot top-1 ImageNet accuracies of **63.2%** in about **2 days**, **67.8%** in about **3 days**, and **69.3%** in about **4 days**.

We further find that CLIPA at scale leads to state-of-the-art performance. For example, our CLIPA-v2 H/14 achieves a zero-shot top-1 ImageNet accuracy of **81.8%**, within a budget of less than **$15,000**.

For more details, please see our papers [An Inverse Scaling Law for CLIP Training](https://arxiv.org/abs/2305.07017) and
[CLIPA-v2: Scaling CLIP Training with 81.1% Zero-shot ImageNet Accuracy within a $10,000 Budget; An Extra $4,000 Unlocks 81.8% Accuracy](https://arxiv.org/abs/2306.15658).

Eight token length reduction strategies are investigated in this work, detailed as follows.

## Image token length reduction

* `resize`: Use `--force-image-size` to specify the image size you want to adopt. We find this strategy generally works best, as it retains full image information.

* `random mask`: Randomly mask out image patches. Use `--force-patch-dropout` to specify the mask ratio you want to adopt.

* `grid mask`: Preserve one patch in each 2 × 2 grid window (a toy sketch of the idea is given after this list). We do not provide an implementation for grid masking, as it is only experimental and we generally find resizing works better.

* `block mask`: Keep a single block of patches and remove the rest. We do not provide an implementation for block masking, as it is only experimental and we generally find resizing works better.
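
To make the token-count effect concrete: with a 16 × 16 patch size, resizing the input from 224 × 224 to 112 × 112 shrinks the image sequence from 196 to 49 patches. The snippet below is only a toy sketch of the grid-masking idea (here, keeping the top-left patch of every 2 × 2 window), not the implementation used in the paper.

```
import torch

def grid_mask(patch_tokens, grid_h, grid_w):
    # patch_tokens: (batch, grid_h * grid_w, dim) patch embeddings in row-major order.
    # Toy grid masking: keep the top-left patch of every 2x2 window,
    # cutting the sequence length by roughly 4x.
    b, _, d = patch_tokens.shape
    grid = patch_tokens.reshape(b, grid_h, grid_w, d)
    kept = grid[:, ::2, ::2, :]
    return kept.reshape(b, -1, d)
```

As noted above, resizing generally works better in practice; the sketch is only meant to clarify what the ablated strategy does.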

## Text token length reduction

* `syntax mask`: Assign different masking priorities to parts of speech (a toy sketch of the idea is given after this list).
Specifically, we prioritize retaining nouns, followed by adjectives, and then other words.
We find this strategy generally works best, as it retains the information most critical for contrastive learning.
To use it, specify `"text_mask": syntax` under `"tokenizer_kwargs"` in the `"text_cfg"` section of the model config `json` file.

* `truncate`: Select the first N text tokens and discard the rest. This is the default setting of `open_clip`.

* `random mask`: Randomly drop a portion of the text tokens.
To use it, specify `"text_mask": random` under `"tokenizer_kwargs"` in the `"text_cfg"` section of the model config `json` file.

* `block mask`: Randomly preserve a consecutive span of text tokens.
To use it, specify `"text_mask": block` under `"tokenizer_kwargs"` in the `"text_cfg"` section of the model config `json` file.
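
As a rough illustration of the syntax-priority idea (nouns first, then adjectives, then everything else), here is a minimal sketch built on NLTK's part-of-speech tagger. It is a toy example that operates on whitespace-level words rather than the model's actual text tokens, and it is not the code shipped in `tokenizer.py`.

```
import nltk  # requires the `punkt` and `averaged_perceptron_tagger` data, see Installation below

def syntax_mask_words(text, keep=8):
    # Rank words by part of speech: nouns (NN*) first, then adjectives (JJ*),
    # then everything else; keep the `keep` highest-priority words in their
    # original order.
    words = nltk.word_tokenize(text)
    tags = nltk.pos_tag(words)
    priority = {'N': 0, 'J': 1}
    ranked = sorted(range(len(words)), key=lambda i: priority.get(tags[i][1][0], 2))
    return [words[i] for i in sorted(ranked[:keep])]
```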
## Installation

Installation is the same as for `open_clip`, except that the `syntax mask` text token length reduction strategy relies on the Natural Language Toolkit (NLTK).
Please follow the [official doc](https://www.nltk.org/) to install NLTK.

Note that the usage of NLTK brings two constraints:
* Because certain functions like `nltk.pos_tag` only support English and Russian for now, the `syntax mask` only works for English.
We have not tested it on Russian or any other language; in principle, it should work the same given a proper language processing toolkit for the target language.
If you still want to apply `syntax mask` to other languages, try finding the right toolkit; otherwise, use one of the other text token length reduction strategies.
* Some NLTK modules, such as `punkt` and `averaged_perceptron_tagger`, need to be downloaded before NLTK can be used.
We have included the downloading code in `tokenizer.py`, but this might cause trouble in certain cases.
You may want to manually download those modules first, via `nltk.download('punkt')` and `nltk.download('averaged_perceptron_tagger')`,
and then set the environment variable `NLTK_DATA` (e.g., `export NLTK_DATA=cache`) before running the script.
Note that this is a one-time effort; a reference snippet is given after this list. Remember to comment out those `nltk.download` lines in `tokenizer.py` afterwards.
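
For reference, a minimal sketch of that one-time download (the `cache` directory here simply mirrors the `NLTK_DATA=cache` example above):

```
import nltk

# One-time download of the tokenizer/tagger data used by the syntax mask.
# Afterwards, set NLTK_DATA to this directory (e.g. `export NLTK_DATA=cache`)
# and comment out the nltk.download lines in tokenizer.py.
nltk.download('punkt', download_dir='cache')
nltk.download('averaged_perceptron_tagger', download_dir='cache')
```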

## Training

We provide example scripts to reproduce our CLIPA results on an A100 eight-GPU machine under `docs/script_examples/clipa`.

For instance, to reproduce the CLIPA-L/16(I37,T8) results, first run the pre-training script
```
bash docs/script_examples/clipa/vit_l16/i37_t8_pretrain.sh
```
and then fine-tune the pre-trained checkpoint with
```
bash docs/script_examples/clipa/vit_l16/i37_t8_finetune.sh
```

- Remember to change the dataset path to your own.
- This is a two-stage training pipeline; when fine-tuning, remember to change the pre-trained checkpoint path to your own.
- Training takes ~3 days for pre-training and ~1 day for fine-tuning on an A100 eight-GPU machine.

## Model Weights

Below are CLIPA weights trained on LAION-400M with an A100 eight-GPU machine.
All models are pre-trained for 6 epochs with reduced input token lengths and subsequently fine-tuned for 0.36 epoch with full input token lengths.

| Model               |                                       Pre-trained Weights                                       | zero-shot IN-1K |
|---------------------|:------------------------------------------------------------------------------------------------:|:---------------:|
| CLIPA-B/16(I50,T16) | [download](https://drive.google.com/file/d/1MDpz8gV2Vjaazk16rBhLxU8811U7_cGL/view?usp=sharing)   |      59.7       |
| CLIPA-L/16(I17,T16) | [download](https://drive.google.com/file/d/1Tr2GYiKAaMH6EGIn5l7eX_1K20eaA3WA/view?usp=sharing)   |      60.3       |
| CLIPA-L/16(I37,T8)  | [download](https://drive.google.com/file/d/1EM1ChRNARpLckkJjf6m7njCY3xyvpGBu/view?usp=sharing)   |      57.9       |

| Model               |                                        Fine-tuned Weights                                       | zero-shot IN-1K |
|---------------------|:------------------------------------------------------------------------------------------------:|:---------------:|
| CLIPA-B/16(I50,T16) | [download](https://drive.google.com/file/d/1fURK0K_a3-83jVEI4PVEbnEJb_V6UbGv/view?usp=sharing)   |      63.2       |
| CLIPA-L/16(I17,T16) | [download](https://drive.google.com/file/d/18qqZGOTGOgb3I3JWONuat6qObsgLq7sR/view?usp=sharing)   |      67.8       |
| CLIPA-L/16(I37,T8)  | [download](https://drive.google.com/file/d/1lV7pLORUK04T9QKKx9TpYtMws-AZrib0/view?usp=sharing)   |      69.3       |

## CLIPA-v2

We also provide example scripts to reproduce our CLIPA-v2 H/14 results under `docs/script_examples/clipav2`.
Note that the original results were obtained with [our JAX implementation](https://github.com/UCSC-VLAA/CLIPA/tree/master/clipa_jax).
These scripts were written by manually transcribing the JAX config files.
As it is infeasible for us to retrain those models with PyTorch, their correctness cannot be verified with 100% confidence. Use them at your own discretion.