diff --git "a/image_classification_timm_peft_lora.ipynb" "b/image_classification_timm_peft_lora.ipynb" new file mode 100644--- /dev/null +++ "b/image_classification_timm_peft_lora.ipynb" @@ -0,0 +1,744 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4ef57047", + "metadata": {}, + "source": [ + "# Using PEFT with timm" + ] + }, + { + "cell_type": "markdown", + "id": "80561acc", + "metadata": {}, + "source": [ + "`peft` allows us to train any model with LoRA as long as the layer type is supported. Since `Conv2D` is one of the supported layer types, it makes sense to test it on image models.\n", + "\n", + "In this short notebook, we will demonstrate this with an image classification task using [`timm`](https://huggingface.co/docs/timm/index)." + ] + }, + { + "cell_type": "markdown", + "id": "aa26c285", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "markdown", + "id": "552b9040", + "metadata": {}, + "source": [ + "Make sure that you have the latest version of `peft` installed. To ensure that, run this in your Python environment:\n", + " \n", + " python -m pip install --upgrade peft\n", + " \n", + "Also, ensure that `timm` is installed:\n", + "\n", + " python -m pip install --upgrade timm" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e600b7d5", + "metadata": {}, + "outputs": [], + "source": [ + "import timm\n", + "import torch\n", + "from PIL import Image\n", + "from timm.data import resolve_data_config\n", + "from timm.data.transforms_factory import create_transform" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "73a2ae54", + "metadata": {}, + "outputs": [], + "source": [ + "import peft\n", + "from datasets import load_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "82c628fd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "markdown", + "id": "701ab69c", + "metadata": {}, + "source": [ + "## Loading the pre-trained base model" + ] + }, + { + "cell_type": "markdown", + "id": "20bff51a", + "metadata": {}, + "source": [ + "We use a small pretrained `timm` model, `PoolFormer`. Find more info on its [model card](https://huggingface.co/timm/poolformer_m36.sail_in1k)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "495cb3d6", + "metadata": {}, + "outputs": [], + "source": [ + "model_id_timm = \"timm/poolformer_m36.sail_in1k\"" + ] + }, + { + "cell_type": "markdown", + "id": "2dc06f9b", + "metadata": {}, + "source": [ + "We tell `timm` that we deal with 3 classes, to ensure that the classification layer has the correct size." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "090564bc", + "metadata": {}, + "outputs": [], + "source": [ + "model = timm.create_model(model_id_timm, pretrained=True, num_classes=3)" + ] + }, + { + "cell_type": "markdown", + "id": "beca5794", + "metadata": {}, + "source": [ + "These are the transformations steps necessary to process the image." 
+  {
+   "cell_type": "markdown",
+   "id": "3f809dfa",
+   "metadata": {},
+   "source": [
+    "## Data"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a398fe22",
+   "metadata": {},
+   "source": [
+    "For this exercise, we use the \"beans\" dataset. More details on the dataset can be found on [its datasets page](https://huggingface.co/datasets/beans). For our purposes, what's important is that we have image inputs and the target we're trying to predict is one of three classes for each image."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "0fddc704",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Found cached dataset beans (/home/vinh/.cache/huggingface/datasets/beans/default/0.0.0/90c755fb6db1c0ccdad02e897a37969dbf070bed3755d4391e269ff70642d791)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "05592574da474b81ab736d6babb5e19d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/3 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "ds = load_dataset(\"beans\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "ab12cd34",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ds_train = ds[\"train\"]\n",
+    "ds_valid = ds[\"validation\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "cd34ef56",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ds_train[0][\"image\"]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "880ea6c4",
+   "metadata": {},
+   "source": [
+    "We define a small processing function that loads and transforms the images and extracts the labels."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "142df842",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def process(batch):\n",
+    "    x = torch.cat([transform(img).unsqueeze(0) for img in batch[\"image\"]])\n",
+    "    y = torch.tensor(batch[\"labels\"])\n",
+    "    return {\"x\": x, \"y\": y}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "9744257b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ds_train.set_transform(process)\n",
+    "ds_valid.set_transform(process)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "282374be",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_loader = torch.utils.data.DataLoader(ds_train, batch_size=32)\n",
+    "valid_loader = torch.utils.data.DataLoader(ds_valid, batch_size=32)"
+   ]
+  },
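+  {
+   "cell_type": "markdown",
+   "id": "33cc44dd",
+   "metadata": {},
+   "source": [
+    "As an optional check that the data pipeline produces what the training loop below expects, we can fetch a single batch and inspect its shapes:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "44dd55ee",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# \"x\" should be a float image tensor of shape (batch_size, 3, H, W) and\n",
+    "# \"y\" the corresponding integer class labels of shape (batch_size,).\n",
+    "batch = next(iter(train_loader))\n",
+    "batch[\"x\"].shape, batch[\"y\"].shape"
+   ]
+  },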
+  {
+   "cell_type": "markdown",
+   "id": "5dcd3329",
+   "metadata": {},
+   "source": [
+    "## Training"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "969bc374",
+   "metadata": {},
+   "source": [
+    "This is just a function that runs the training loop; nothing fancy is happening here."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "b9fc9588",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(model, optimizer, criterion, train_dataloader, valid_dataloader, epochs):\n",
+    "    for epoch in range(epochs):\n",
+    "        # training phase\n",
+    "        model.train()\n",
+    "        train_loss = 0\n",
+    "        for batch in train_dataloader:\n",
+    "            xb, yb = batch[\"x\"], batch[\"y\"]\n",
+    "            xb, yb = xb.to(device), yb.to(device)\n",
+    "            outputs = model(xb)\n",
+    "            lsm = torch.nn.functional.log_softmax(outputs, dim=-1)\n",
+    "            loss = criterion(lsm, yb)\n",
+    "            train_loss += loss.detach().float()\n",
+    "            loss.backward()\n",
+    "            optimizer.step()\n",
+    "            optimizer.zero_grad()\n",
+    "\n",
+    "        # validation phase: no gradients needed, also track the accuracy\n",
+    "        model.eval()\n",
+    "        valid_loss = 0\n",
+    "        correct = 0\n",
+    "        n_total = 0\n",
+    "        for batch in valid_dataloader:\n",
+    "            xb, yb = batch[\"x\"], batch[\"y\"]\n",
+    "            xb, yb = xb.to(device), yb.to(device)\n",
+    "            with torch.no_grad():\n",
+    "                outputs = model(xb)\n",
+    "            lsm = torch.nn.functional.log_softmax(outputs, dim=-1)\n",
+    "            loss = criterion(lsm, yb)\n",
+    "            valid_loss += loss.detach().float()\n",
+    "            correct += (outputs.argmax(-1) == yb).sum().item()\n",
+    "            n_total += len(yb)\n",
+    "\n",
+    "        train_loss_total = (train_loss / len(train_dataloader)).item()\n",
+    "        valid_loss_total = (valid_loss / len(valid_dataloader)).item()\n",
+    "        valid_acc_total = correct / n_total\n",
+    "        print(f\"{epoch=:<2} {train_loss_total=:.4f} {valid_loss_total=:.4f} {valid_acc_total=:.4f}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3fd58357",
+   "metadata": {},
+   "source": [
+    "### Selecting which layers to fine-tune with LoRA"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7987321c",
+   "metadata": {},
+   "source": [
+    "Let's take a look at the layers of our model. We only print the first 30, since there are quite a few:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "55a7be4d",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[('', timm.models.metaformer.MetaFormer),\n",
+       " ('stem', timm.models.metaformer.Stem),\n",
+       " ('stem.conv', torch.nn.modules.conv.Conv2d),\n",
+       " ('stem.norm', torch.nn.modules.linear.Identity),\n",
+       " ('stages', torch.nn.modules.container.Sequential),\n",
+       " ('stages.0', timm.models.metaformer.MetaFormerStage),\n",
+       " ('stages.0.downsample', torch.nn.modules.linear.Identity),\n",
+       " ('stages.0.blocks', torch.nn.modules.container.Sequential),\n",
+       " ('stages.0.blocks.0', timm.models.metaformer.MetaFormerBlock),\n",
+       " ('stages.0.blocks.0.norm1', timm.layers.norm.GroupNorm1),\n",
+       " ('stages.0.blocks.0.token_mixer', timm.models.metaformer.Pooling),\n",
+       " ('stages.0.blocks.0.token_mixer.pool', torch.nn.modules.pooling.AvgPool2d),\n",
+       " ('stages.0.blocks.0.drop_path1', torch.nn.modules.linear.Identity),\n",
+       " ('stages.0.blocks.0.layer_scale1', timm.models.metaformer.Scale),\n",
+       " ('stages.0.blocks.0.res_scale1', torch.nn.modules.linear.Identity),\n",
+       " ('stages.0.blocks.0.norm2', timm.layers.norm.GroupNorm1),\n",
+       " ('stages.0.blocks.0.mlp', timm.layers.mlp.Mlp),\n",
+       " ('stages.0.blocks.0.mlp.fc1', torch.nn.modules.conv.Conv2d),\n",
+       " ('stages.0.blocks.0.mlp.act', torch.nn.modules.activation.GELU),\n",
+       " ('stages.0.blocks.0.mlp.drop1', torch.nn.modules.dropout.Dropout),\n",
+       " ('stages.0.blocks.0.mlp.norm', torch.nn.modules.linear.Identity),\n",
+       " ('stages.0.blocks.0.mlp.fc2', torch.nn.modules.conv.Conv2d),\n",
+       " ('stages.0.blocks.0.mlp.drop2', torch.nn.modules.dropout.Dropout),\n",
+       " ('stages.0.blocks.0.drop_path2', torch.nn.modules.linear.Identity),\n",
+       " ('stages.0.blocks.0.layer_scale2', timm.models.metaformer.Scale),\n",
+       " ('stages.0.blocks.0.res_scale2', torch.nn.modules.linear.Identity),\n",
+       " ('stages.0.blocks.1', timm.models.metaformer.MetaFormerBlock),\n",
+       " ('stages.0.blocks.1.norm1', timm.layers.norm.GroupNorm1),\n",
+       " ('stages.0.blocks.1.token_mixer', timm.models.metaformer.Pooling),\n",
+       " ('stages.0.blocks.1.token_mixer.pool', torch.nn.modules.pooling.AvgPool2d)]"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "[(n, type(m)) for n, m in model.named_modules()][:30]"
+   ]
+  },
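+  {
+   "cell_type": "markdown",
+   "id": "55ee66ff",
+   "metadata": {},
+   "source": [
+    "The listing mixes many module types. As an optional step, we can narrow it down to just the `Conv2d` modules, since those are the candidates for LoRA here:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "66ff77aa",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Collect the names of all Conv2d modules; besides a few others (such as the\n",
+    "# stem convolution), this surfaces the mlp.fc1/mlp.fc2 layers seen above.\n",
+    "conv_names = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Conv2d)]\n",
+    "conv_names[:5]"
+   ]
+  },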
+  {
+   "cell_type": "markdown",
+   "id": "09af9349",
+   "metadata": {},
+   "source": [
+    "Most of these layers are not good targets for LoRA, but we see a couple that should interest us. Their names are `'stages.0.blocks.0.mlp.fc1'`, etc. With a bit of regex, we can match them easily.\n",
+    "\n",
+    "Also, we should inspect the name of the classification layer, since we want to train that one too!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "8b98d9ef",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[('head.global_pool.flatten', torch.nn.modules.linear.Identity),\n",
+       " ('head.norm', timm.layers.norm.LayerNorm2d),\n",
+       " ('head.flatten', torch.nn.modules.flatten.Flatten),\n",
+       " ('head.drop', torch.nn.modules.linear.Identity),\n",
+       " ('head.fc', torch.nn.modules.linear.Linear)]"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "[(n, type(m)) for n, m in model.named_modules()][-5:]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "00e75b78",
+   "metadata": {},
+   "source": [
+    "Okay, this gives us all the information we need to fine-tune this model. In principle, the regex could match the classification head as well, applying LoRA to it too:\n",
+    "\n",
+    "    config = peft.LoraConfig(\n",
+    "        r=8,\n",
+    "        target_modules=r\".*\\.mlp\\.fc\\d|head\\.fc\",\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "23814d70",
+   "metadata": {},
+   "source": [
+    "However, we want to fully train the classification layer `head.fc` rather than apply LoRA to it. So we match only the convolutional layers with the regex and add `head.fc` to the `modules_to_save`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "81029587",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config = peft.LoraConfig(r=8, target_modules=r\".*\\.mlp\\.fc\\d\", modules_to_save=[\"head.fc\"])"
+   ]
+  },
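+  {
+   "cell_type": "markdown",
+   "id": "77aa88bb",
+   "metadata": {},
+   "source": [
+    "Since `target_modules` is given as a string, `peft` treats it as a regex and matches it against the full module names (with `re.fullmatch`). As an optional check before wrapping the model, we can approximate which layers will receive LoRA adapters:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "88bb99cc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import re\n",
+    "\n",
+    "# Module names that the target_modules regex fully matches; these are the\n",
+    "# layers that get_peft_model should replace with LoRA versions.\n",
+    "matched = [n for n, _ in model.named_modules() if re.fullmatch(config.target_modules, n)]\n",
+    "len(matched), matched[:4]"
+   ]
+  },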
+  {
+   "cell_type": "markdown",
+   "id": "e05876bc",
+   "metadata": {},
+   "source": [
+    "Finally, let's create the `peft` model, the optimizer, and the criterion, and then we can get started. As shown below, thanks to `peft`, less than 2% of the model's total parameters are trainable."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "8cc5c5db",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "trainable params: 1,064,454 || all params: 56,467,974 || trainable%: 1.88505789139876\n"
+     ]
+    }
+   ],
+   "source": [
+    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+    "peft_model = peft.get_peft_model(model, config).to(device)\n",
+    "optimizer = torch.optim.Adam(peft_model.parameters(), lr=2e-4)\n",
+    "criterion = torch.nn.CrossEntropyLoss()\n",
+    "peft_model.print_trainable_parameters()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "9e557e42",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "epoch=0 train_loss_total=1.2999 valid_loss_total=1.0624 valid_acc_total=0.4436\n",
+      "epoch=1 train_loss_total=1.0200 valid_loss_total=0.8906 valid_acc_total=0.7594\n",
+      "epoch=2 train_loss_total=0.8874 valid_loss_total=0.6894 valid_acc_total=0.8045\n",
+      "epoch=3 train_loss_total=0.7440 valid_loss_total=0.4797 valid_acc_total=0.8045\n",
+      "epoch=4 train_loss_total=0.6025 valid_loss_total=0.3419 valid_acc_total=0.8120\n",
+      "epoch=5 train_loss_total=0.4820 valid_loss_total=0.2589 valid_acc_total=0.8421\n",
+      "epoch=6 train_loss_total=0.3567 valid_loss_total=0.2101 valid_acc_total=0.8722\n",
+      "epoch=7 train_loss_total=0.2835 valid_loss_total=0.1385 valid_acc_total=0.9098\n",
+      "epoch=8 train_loss_total=0.1815 valid_loss_total=0.1108 valid_acc_total=0.9474\n",
+      "epoch=9 train_loss_total=0.1341 valid_loss_total=0.0785 valid_acc_total=0.9699\n",
+      "CPU times: user 4min 3s, sys: 36.3 s, total: 4min 40s\n",
+      "Wall time: 3min 32s\n"
+     ]
+    }
+   ],
+   "source": [
+    "%time train(peft_model, optimizer, criterion, train_loader, valid_dataloader=valid_loader, epochs=10)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "94162859",
+   "metadata": {},
+   "source": [
+    "We get a validation accuracy of ~0.97 despite training only a tiny fraction of the parameters. That's a really nice result."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9c16bad8",
+   "metadata": {},
+   "source": [
+    "## Sharing the model through Hugging Face Hub"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2e1e16c7",
+   "metadata": {},
+   "source": [
+    "### Pushing the model to Hugging Face Hub"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ec596b3b",
+   "metadata": {},
+   "source": [
+    "If we want to share the fine-tuned weights with the world, we can upload them to the Hugging Face Hub like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "b583579d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "user = \"BenjaminB\"  # put your user name here\n",
+    "model_name = \"peft-lora-with-timm-model\"\n",
+    "model_id = f\"{user}/{model_name}\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "f1db67e4",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "aed1f9c3fa334be1b5f208efe5ba27e6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Upload 1 LFS files: 0%| | 0/1 [00:00