diff --git "a/example_images_tokenize.ipynb" "b/example_images_tokenize.ipynb" new file mode 100644--- /dev/null +++ "b/example_images_tokenize.ipynb" @@ -0,0 +1,2643 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "4924a372-61e0-4806-8820-a716b01787e8", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "VQBASE(\n", + " (encoder): Encoder(\n", + " (model): Sequential(\n", + " (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): ResnetBlock(\n", + " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (2): ResnetBlock(\n", + " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n", + " )\n", + " (4): ResnetBlock(\n", + " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (5): ResnetBlock(\n", + " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (6): Downsample(\n", + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n", + " )\n", + " (7): ResnetBlock(\n", + " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (nin_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (8): ResnetBlock(\n", + " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (9): Downsample(\n", + " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n", + " )\n", + " (10): ResnetBlock(\n", + " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (nin_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (11): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (12): Downsample(\n", + " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))\n", + " )\n", + " (13): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (14): AttnBlock(\n", + " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (15): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (16): AttnBlock(\n", + " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (17): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (18): AttnBlock(\n", + " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (19): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (20): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (21): Swish()\n", + " (22): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " )\n", + " (decoder): Decoder(\n", + " (model): Sequential(\n", + " (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (2): AttnBlock(\n", + " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (3): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (4): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (5): AttnBlock(\n", + " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (6): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (7): AttnBlock(\n", + " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (8): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (9): AttnBlock(\n", + " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (10): Upsample(\n", + " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (11): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (12): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (13): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (14): Upsample(\n", + " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (15): ResnetBlock(\n", + " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (nin_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (16): ResnetBlock(\n", + " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (17): ResnetBlock(\n", + " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (18): Upsample(\n", + " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (19): ResnetBlock(\n", + " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (nin_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (20): ResnetBlock(\n", + " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (21): ResnetBlock(\n", + " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (22): Upsample(\n", + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (23): ResnetBlock(\n", + " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (24): ResnetBlock(\n", + " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (25): ResnetBlock(\n", + " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (26): GroupNorm(32, 128, eps=1e-06, affine=True)\n", + " (27): Swish()\n", + " (28): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " )\n", + " (quantize): Codebook(\n", + " (embedding): Embedding(8192, 256)\n", + " )\n", + " (quant_conv): Sequential(\n", + " (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (post_quant_conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n", + ")" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os\n", + "import yaml\n", + "import time\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from vqvae import VQBASE\n", + "import webdataset as wds\n", + "import albumentations as A\n", + "from albumentations.pytorch import ToTensorV2\n", + "from torchvision.utils import make_grid, save_image\n", + "device = \"cuda\"\n", + "\n", + "# Load configuration for the model\n", + "with open(\"make_a_scene/img_config.yaml\", 'r') as file:\n", + " params = yaml.safe_load(file)[\"model\"]\n", + " del params[\"_target_\"]\n", + "\n", + "\n", + "# Initialize and load the second model in bfloat16\n", + "vq_vae = VQBASE(**params).to(device)\n", + "vq_vae.load_state_dict(torch.load(\"make_a_scene/checkpoint_63.0.pt\", map_location=device)[\"model\"])\n", + "vq_vae = vq_vae.to(dtype=torch.bfloat16)\n", + "vq_vae.eval().requires_grad_(False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e14eaa4d-d7c3-4f59-b385-1ab45331c150", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers.tokenization_utils_fast import PreTrainedTokenizerFast" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "bfe41742-65d6-446b-b405-f21261ab3e68", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/alandao/.local/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + } + ], + "source": [ + "llama_tokenizer = PreTrainedTokenizerFast.from_pretrained(\"NousResearch/Meta-Llama-3-8B-Instruct\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8f26a6bf-f095-41c1-bc3b-fd2381867708", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Embedding(8192, 256)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Number of to-add vocabs = number of codes in codebook of vq-vae\n", + "vq_vae.quantize.embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e474eef7-affa-4991-a77a-5ef010147223", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'<|begin_of_text|>': 128000,\n", + " '<|end_of_text|>': 128001,\n", + " '<|reserved_special_token_0|>': 128002,\n", + " '<|reserved_special_token_1|>': 128003,\n", + " '<|reserved_special_token_2|>': 128004,\n", + " '<|reserved_special_token_3|>': 128005,\n", + " '<|start_header_id|>': 128006,\n", + " '<|end_header_id|>': 128007,\n", + " '<|reserved_special_token_4|>': 128008,\n", + " '<|eot_id|>': 128009,\n", + " '<|reserved_special_token_5|>': 128010,\n", + " '<|reserved_special_token_6|>': 128011,\n", + " '<|reserved_special_token_7|>': 128012,\n", + " '<|reserved_special_token_8|>': 128013,\n", + " '<|reserved_special_token_9|>': 128014,\n", + " '<|reserved_special_token_10|>': 128015,\n", + " '<|reserved_special_token_11|>': 128016,\n", + " '<|reserved_special_token_12|>': 128017,\n", + " '<|reserved_special_token_13|>': 128018,\n", + " '<|reserved_special_token_14|>': 128019,\n", + " '<|reserved_special_token_15|>': 128020,\n", + " '<|reserved_special_token_16|>': 128021,\n", + " '<|reserved_special_token_17|>': 128022,\n", + " '<|reserved_special_token_18|>': 128023,\n", + " '<|reserved_special_token_19|>': 128024,\n", + " '<|reserved_special_token_20|>': 128025,\n", + " '<|reserved_special_token_21|>': 128026,\n", + " '<|reserved_special_token_22|>': 128027,\n", + " '<|reserved_special_token_23|>': 128028,\n", + " '<|reserved_special_token_24|>': 128029,\n", + " '<|reserved_special_token_25|>': 128030,\n", + " '<|reserved_special_token_26|>': 128031,\n", + " '<|reserved_special_token_27|>': 128032,\n", + " '<|reserved_special_token_28|>': 128033,\n", + " '<|reserved_special_token_29|>': 128034,\n", + " '<|reserved_special_token_30|>': 128035,\n", + " '<|reserved_special_token_31|>': 128036,\n", + " '<|reserved_special_token_32|>': 128037,\n", + " '<|reserved_special_token_33|>': 128038,\n", + " '<|reserved_special_token_34|>': 128039,\n", + " '<|reserved_special_token_35|>': 128040,\n", + " '<|reserved_special_token_36|>': 128041,\n", + " '<|reserved_special_token_37|>': 128042,\n", + " '<|reserved_special_token_38|>': 128043,\n", + " '<|reserved_special_token_39|>': 128044,\n", + " '<|reserved_special_token_40|>': 128045,\n", + " '<|reserved_special_token_41|>': 128046,\n", + " '<|reserved_special_token_42|>': 128047,\n", + " '<|reserved_special_token_43|>': 128048,\n", + " '<|reserved_special_token_44|>': 128049,\n", + " '<|reserved_special_token_45|>': 128050,\n", + " '<|reserved_special_token_46|>': 128051,\n", + " '<|reserved_special_token_47|>': 128052,\n", + " '<|reserved_special_token_48|>': 128053,\n", + " '<|reserved_special_token_49|>': 128054,\n", + " '<|reserved_special_token_50|>': 128055,\n", + " '<|reserved_special_token_51|>': 128056,\n", + " '<|reserved_special_token_52|>': 128057,\n", + " '<|reserved_special_token_53|>': 128058,\n", + " '<|reserved_special_token_54|>': 128059,\n", + " '<|reserved_special_token_55|>': 128060,\n", + " '<|reserved_special_token_56|>': 128061,\n", + " '<|reserved_special_token_57|>': 128062,\n", + " '<|reserved_special_token_58|>': 128063,\n", + " '<|reserved_special_token_59|>': 128064,\n", + " '<|reserved_special_token_60|>': 128065,\n", + " '<|reserved_special_token_61|>': 128066,\n", + " '<|reserved_special_token_62|>': 128067,\n", + " '<|reserved_special_token_63|>': 128068,\n", + " '<|reserved_special_token_64|>': 128069,\n", + " '<|reserved_special_token_65|>': 128070,\n", + " '<|reserved_special_token_66|>': 128071,\n", + " '<|reserved_special_token_67|>': 128072,\n", + " '<|reserved_special_token_68|>': 128073,\n", + " '<|reserved_special_token_69|>': 128074,\n", + " '<|reserved_special_token_70|>': 128075,\n", + " '<|reserved_special_token_71|>': 128076,\n", + " '<|reserved_special_token_72|>': 128077,\n", + " '<|reserved_special_token_73|>': 128078,\n", + " '<|reserved_special_token_74|>': 128079,\n", + " '<|reserved_special_token_75|>': 128080,\n", + " '<|reserved_special_token_76|>': 128081,\n", + " '<|reserved_special_token_77|>': 128082,\n", + " '<|reserved_special_token_78|>': 128083,\n", + " '<|reserved_special_token_79|>': 128084,\n", + " '<|reserved_special_token_80|>': 128085,\n", + " '<|reserved_special_token_81|>': 128086,\n", + " '<|reserved_special_token_82|>': 128087,\n", + " '<|reserved_special_token_83|>': 128088,\n", + " '<|reserved_special_token_84|>': 128089,\n", + " '<|reserved_special_token_85|>': 128090,\n", + " '<|reserved_special_token_86|>': 128091,\n", + " '<|reserved_special_token_87|>': 128092,\n", + " '<|reserved_special_token_88|>': 128093,\n", + " '<|reserved_special_token_89|>': 128094,\n", + " '<|reserved_special_token_90|>': 128095,\n", + " '<|reserved_special_token_91|>': 128096,\n", + " '<|reserved_special_token_92|>': 128097,\n", + " '<|reserved_special_token_93|>': 128098,\n", + " '<|reserved_special_token_94|>': 128099,\n", + " '<|reserved_special_token_95|>': 128100,\n", + " '<|reserved_special_token_96|>': 128101,\n", + " '<|reserved_special_token_97|>': 128102,\n", + " '<|reserved_special_token_98|>': 128103,\n", + " '<|reserved_special_token_99|>': 128104,\n", + " '<|reserved_special_token_100|>': 128105,\n", + " '<|reserved_special_token_101|>': 128106,\n", + " '<|reserved_special_token_102|>': 128107,\n", + " '<|reserved_special_token_103|>': 128108,\n", + " '<|reserved_special_token_104|>': 128109,\n", + " '<|reserved_special_token_105|>': 128110,\n", + " '<|reserved_special_token_106|>': 128111,\n", + " '<|reserved_special_token_107|>': 128112,\n", + " '<|reserved_special_token_108|>': 128113,\n", + " '<|reserved_special_token_109|>': 128114,\n", + " '<|reserved_special_token_110|>': 128115,\n", + " '<|reserved_special_token_111|>': 128116,\n", + " '<|reserved_special_token_112|>': 128117,\n", + " '<|reserved_special_token_113|>': 128118,\n", + " '<|reserved_special_token_114|>': 128119,\n", + " '<|reserved_special_token_115|>': 128120,\n", + " '<|reserved_special_token_116|>': 128121,\n", + " '<|reserved_special_token_117|>': 128122,\n", + " '<|reserved_special_token_118|>': 128123,\n", + " '<|reserved_special_token_119|>': 128124,\n", + " '<|reserved_special_token_120|>': 128125,\n", + " '<|reserved_special_token_121|>': 128126,\n", + " '<|reserved_special_token_122|>': 128127,\n", + " '<|reserved_special_token_123|>': 128128,\n", + " '<|reserved_special_token_124|>': 128129,\n", + " '<|reserved_special_token_125|>': 128130,\n", + " '<|reserved_special_token_126|>': 128131,\n", + " '<|reserved_special_token_127|>': 128132,\n", + " '<|reserved_special_token_128|>': 128133,\n", + " '<|reserved_special_token_129|>': 128134,\n", + " '<|reserved_special_token_130|>': 128135,\n", + " '<|reserved_special_token_131|>': 128136,\n", + " '<|reserved_special_token_132|>': 128137,\n", + " '<|reserved_special_token_133|>': 128138,\n", + " '<|reserved_special_token_134|>': 128139,\n", + " '<|reserved_special_token_135|>': 128140,\n", + " '<|reserved_special_token_136|>': 128141,\n", + " '<|reserved_special_token_137|>': 128142,\n", + " '<|reserved_special_token_138|>': 128143,\n", + " '<|reserved_special_token_139|>': 128144,\n", + " '<|reserved_special_token_140|>': 128145,\n", + " '<|reserved_special_token_141|>': 128146,\n", + " '<|reserved_special_token_142|>': 128147,\n", + " '<|reserved_special_token_143|>': 128148,\n", + " '<|reserved_special_token_144|>': 128149,\n", + " '<|reserved_special_token_145|>': 128150,\n", + " '<|reserved_special_token_146|>': 128151,\n", + " '<|reserved_special_token_147|>': 128152,\n", + " '<|reserved_special_token_148|>': 128153,\n", + " '<|reserved_special_token_149|>': 128154,\n", + " '<|reserved_special_token_150|>': 128155,\n", + " '<|reserved_special_token_151|>': 128156,\n", + " '<|reserved_special_token_152|>': 128157,\n", + " '<|reserved_special_token_153|>': 128158,\n", + " '<|reserved_special_token_154|>': 128159,\n", + " '<|reserved_special_token_155|>': 128160,\n", + " '<|reserved_special_token_156|>': 128161,\n", + " '<|reserved_special_token_157|>': 128162,\n", + " '<|reserved_special_token_158|>': 128163,\n", + " '<|reserved_special_token_159|>': 128164,\n", + " '<|reserved_special_token_160|>': 128165,\n", + " '<|reserved_special_token_161|>': 128166,\n", + " '<|reserved_special_token_162|>': 128167,\n", + " '<|reserved_special_token_163|>': 128168,\n", + " '<|reserved_special_token_164|>': 128169,\n", + " '<|reserved_special_token_165|>': 128170,\n", + " '<|reserved_special_token_166|>': 128171,\n", + " '<|reserved_special_token_167|>': 128172,\n", + " '<|reserved_special_token_168|>': 128173,\n", + " '<|reserved_special_token_169|>': 128174,\n", + " '<|reserved_special_token_170|>': 128175,\n", + " '<|reserved_special_token_171|>': 128176,\n", + " '<|reserved_special_token_172|>': 128177,\n", + " '<|reserved_special_token_173|>': 128178,\n", + " '<|reserved_special_token_174|>': 128179,\n", + " '<|reserved_special_token_175|>': 128180,\n", + " '<|reserved_special_token_176|>': 128181,\n", + " '<|reserved_special_token_177|>': 128182,\n", + " '<|reserved_special_token_178|>': 128183,\n", + " '<|reserved_special_token_179|>': 128184,\n", + " '<|reserved_special_token_180|>': 128185,\n", + " '<|reserved_special_token_181|>': 128186,\n", + " '<|reserved_special_token_182|>': 128187,\n", + " '<|reserved_special_token_183|>': 128188,\n", + " '<|reserved_special_token_184|>': 128189,\n", + " '<|reserved_special_token_185|>': 128190,\n", + " '<|reserved_special_token_186|>': 128191,\n", + " '<|reserved_special_token_187|>': 128192,\n", + " '<|reserved_special_token_188|>': 128193,\n", + " '<|reserved_special_token_189|>': 128194,\n", + " '<|reserved_special_token_190|>': 128195,\n", + " '<|reserved_special_token_191|>': 128196,\n", + " '<|reserved_special_token_192|>': 128197,\n", + " '<|reserved_special_token_193|>': 128198,\n", + " '<|reserved_special_token_194|>': 128199,\n", + " '<|reserved_special_token_195|>': 128200,\n", + " '<|reserved_special_token_196|>': 128201,\n", + " '<|reserved_special_token_197|>': 128202,\n", + " '<|reserved_special_token_198|>': 128203,\n", + " '<|reserved_special_token_199|>': 128204,\n", + " '<|reserved_special_token_200|>': 128205,\n", + " '<|reserved_special_token_201|>': 128206,\n", + " '<|reserved_special_token_202|>': 128207,\n", + " '<|reserved_special_token_203|>': 128208,\n", + " '<|reserved_special_token_204|>': 128209,\n", + " '<|reserved_special_token_205|>': 128210,\n", + " '<|reserved_special_token_206|>': 128211,\n", + " '<|reserved_special_token_207|>': 128212,\n", + " '<|reserved_special_token_208|>': 128213,\n", + " '<|reserved_special_token_209|>': 128214,\n", + " '<|reserved_special_token_210|>': 128215,\n", + " '<|reserved_special_token_211|>': 128216,\n", + " '<|reserved_special_token_212|>': 128217,\n", + " '<|reserved_special_token_213|>': 128218,\n", + " '<|reserved_special_token_214|>': 128219,\n", + " '<|reserved_special_token_215|>': 128220,\n", + " '<|reserved_special_token_216|>': 128221,\n", + " '<|reserved_special_token_217|>': 128222,\n", + " '<|reserved_special_token_218|>': 128223,\n", + " '<|reserved_special_token_219|>': 128224,\n", + " '<|reserved_special_token_220|>': 128225,\n", + " '<|reserved_special_token_221|>': 128226,\n", + " '<|reserved_special_token_222|>': 128227,\n", + " '<|reserved_special_token_223|>': 128228,\n", + " '<|reserved_special_token_224|>': 128229,\n", + " '<|reserved_special_token_225|>': 128230,\n", + " '<|reserved_special_token_226|>': 128231,\n", + " '<|reserved_special_token_227|>': 128232,\n", + " '<|reserved_special_token_228|>': 128233,\n", + " '<|reserved_special_token_229|>': 128234,\n", + " '<|reserved_special_token_230|>': 128235,\n", + " '<|reserved_special_token_231|>': 128236,\n", + " '<|reserved_special_token_232|>': 128237,\n", + " '<|reserved_special_token_233|>': 128238,\n", + " '<|reserved_special_token_234|>': 128239,\n", + " '<|reserved_special_token_235|>': 128240,\n", + " '<|reserved_special_token_236|>': 128241,\n", + " '<|reserved_special_token_237|>': 128242,\n", + " '<|reserved_special_token_238|>': 128243,\n", + " '<|reserved_special_token_239|>': 128244,\n", + " '<|reserved_special_token_240|>': 128245,\n", + " '<|reserved_special_token_241|>': 128246,\n", + " '<|reserved_special_token_242|>': 128247,\n", + " '<|reserved_special_token_243|>': 128248,\n", + " '<|reserved_special_token_244|>': 128249,\n", + " '<|reserved_special_token_245|>': 128250,\n", + " '<|reserved_special_token_246|>': 128251,\n", + " '<|reserved_special_token_247|>': 128252,\n", + " '<|reserved_special_token_248|>': 128253,\n", + " '<|reserved_special_token_249|>': 128254,\n", + " '<|reserved_special_token_250|>': 128255,\n", + " '<|img_start|>': 128256,\n", + " '<|img_end|>': 128257,\n", + " '': 128258,\n", + " '': 128259,\n", + " '': 128260,\n", + " '': 128261,\n", + " '': 128262,\n", + " '': 128263,\n", + " '': 128264,\n", + " '': 128265,\n", + " '': 128266,\n", + " '': 128267,\n", + " '': 128268,\n", + " '': 128269,\n", + " '': 128270,\n", + " '': 128271,\n", + " '': 128272,\n", + " '': 128273,\n", + " '': 128274,\n", + " '': 128275,\n", + " '': 128276,\n", + " '': 128277,\n", + " '': 128278,\n", + " '': 128279,\n", + " '': 128280,\n", + " '': 128281,\n", + " '': 128282,\n", + " '': 128283,\n", + " '': 128284,\n", + " '': 128285,\n", + " '': 128286,\n", + " '': 128287,\n", + " '': 128288,\n", + " '': 128289,\n", + " '': 128290,\n", + " '': 128291,\n", + " '': 128292,\n", + " '': 128293,\n", + " '': 128294,\n", + " '': 128295,\n", + " '': 128296,\n", + " '': 128297,\n", + " '': 128298,\n", + " '': 128299,\n", + " '': 128300,\n", + " '': 128301,\n", + " '': 128302,\n", + " '': 128303,\n", + " '': 128304,\n", + " '': 128305,\n", + " '': 128306,\n", + " '': 128307,\n", + " '': 128308,\n", + " '': 128309,\n", + " '': 128310,\n", + " '': 128311,\n", + " '': 128312,\n", + " '': 128313,\n", + " '': 128314,\n", + " '': 128315,\n", + " '': 128316,\n", + " '': 128317,\n", + " '': 128318,\n", + " '': 128319,\n", + " '': 128320,\n", + " '': 128321,\n", + " '': 128322,\n", + " '': 128323,\n", + " '': 128324,\n", + " '': 128325,\n", + " '': 128326,\n", + " '': 128327,\n", + " '': 128328,\n", + " '': 128329,\n", + " '': 128330,\n", + " '': 128331,\n", + " '': 128332,\n", + " '': 128333,\n", + " '': 128334,\n", + " '': 128335,\n", + " '': 128336,\n", + " '': 128337,\n", + " '': 128338,\n", + " '': 128339,\n", + " '': 128340,\n", + " '': 128341,\n", + " '': 128342,\n", + " '': 128343,\n", + " '': 128344,\n", + " '': 128345,\n", + " '': 128346,\n", + " '': 128347,\n", + " '': 128348,\n", + " '': 128349,\n", + " '': 128350,\n", + " '': 128351,\n", + " '': 128352,\n", + " '': 128353,\n", + " '': 128354,\n", + " '': 128355,\n", + " '': 128356,\n", + " '': 128357,\n", + " '': 128358,\n", + " '': 128359,\n", + " '': 128360,\n", + " '': 128361,\n", + " '': 128362,\n", + " '': 128363,\n", + " '': 128364,\n", + " '': 128365,\n", + " '': 128366,\n", + " '': 128367,\n", + " '': 128368,\n", + " '': 128369,\n", + " '': 128370,\n", + " '': 128371,\n", + " '': 128372,\n", + " '': 128373,\n", + " '': 128374,\n", + " '': 128375,\n", + " '': 128376,\n", + " '': 128377,\n", + " '': 128378,\n", + " '': 128379,\n", + " '': 128380,\n", + " '': 128381,\n", + " '': 128382,\n", + " '': 128383,\n", + " '': 128384,\n", + " '': 128385,\n", + " '': 128386,\n", + " '': 128387,\n", + " '': 128388,\n", + " '': 128389,\n", + " '': 128390,\n", + " '': 128391,\n", + " '': 128392,\n", + " '': 128393,\n", + " '': 128394,\n", + " '': 128395,\n", + " '': 128396,\n", + " '': 128397,\n", + " '': 128398,\n", + " '': 128399,\n", + " '': 128400,\n", + " '': 128401,\n", + " '': 128402,\n", + " '': 128403,\n", + " '': 128404,\n", + " '': 128405,\n", + " '': 128406,\n", + " '': 128407,\n", + " '': 128408,\n", + " '': 128409,\n", + " '': 128410,\n", + " '': 128411,\n", + " '': 128412,\n", + " '': 128413,\n", + " '': 128414,\n", + " '': 128415,\n", + " '': 128416,\n", + " '': 128417,\n", + " '': 128418,\n", + " '': 128419,\n", + " '': 128420,\n", + " '': 128421,\n", + " '': 128422,\n", + " '': 128423,\n", + " '': 128424,\n", + " '': 128425,\n", + " '': 128426,\n", + " '': 128427,\n", + " '': 128428,\n", + " '': 128429,\n", + " '': 128430,\n", + " '': 128431,\n", + " '': 128432,\n", + " '': 128433,\n", + " '': 128434,\n", + " '': 128435,\n", + " '': 128436,\n", + " '': 128437,\n", + " '': 128438,\n", + " '': 128439,\n", + " '': 128440,\n", + " '': 128441,\n", + " '': 128442,\n", + " '': 128443,\n", + " '': 128444,\n", + " '': 128445,\n", + " '': 128446,\n", + " '': 128447,\n", + " '': 128448,\n", + " '': 128449,\n", + " '': 128450,\n", + " '': 128451,\n", + " '': 128452,\n", + " '': 128453,\n", + " '': 128454,\n", + " '': 128455,\n", + " '': 128456,\n", + " '': 128457,\n", + " '': 128458,\n", + " '': 128459,\n", + " '': 128460,\n", + " '': 128461,\n", + " '': 128462,\n", + " '': 128463,\n", + " '': 128464,\n", + " '': 128465,\n", + " '': 128466,\n", + " '': 128467,\n", + " '': 128468,\n", + " '': 128469,\n", + " '': 128470,\n", + " '': 128471,\n", + " '': 128472,\n", + " '': 128473,\n", + " '': 128474,\n", + " '': 128475,\n", + " '': 128476,\n", + " '': 128477,\n", + " '': 128478,\n", + " '': 128479,\n", + " '': 128480,\n", + " '': 128481,\n", + " '': 128482,\n", + " '': 128483,\n", + " '': 128484,\n", + " '': 128485,\n", + " '': 128486,\n", + " '': 128487,\n", + " '': 128488,\n", + " '': 128489,\n", + " '': 128490,\n", + " '': 128491,\n", + " '': 128492,\n", + " '': 128493,\n", + " '': 128494,\n", + " '': 128495,\n", + " '': 128496,\n", + " '': 128497,\n", + " '': 128498,\n", + " '': 128499,\n", + " '': 128500,\n", + " '': 128501,\n", + " '': 128502,\n", + " '': 128503,\n", + " '': 128504,\n", + " '': 128505,\n", + " '': 128506,\n", + " '': 128507,\n", + " '': 128508,\n", + " '': 128509,\n", + " '': 128510,\n", + " '': 128511,\n", + " '': 128512,\n", + " '': 128513,\n", + " '': 128514,\n", + " '': 128515,\n", + " '': 128516,\n", + " '': 128517,\n", + " '': 128518,\n", + " '': 128519,\n", + " '': 128520,\n", + " '': 128521,\n", + " '': 128522,\n", + " '': 128523,\n", + " '': 128524,\n", + " '': 128525,\n", + " '': 128526,\n", + " '': 128527,\n", + " '': 128528,\n", + " '': 128529,\n", + " '': 128530,\n", + " '': 128531,\n", + " '': 128532,\n", + " '': 128533,\n", + " '': 128534,\n", + " '': 128535,\n", + " '': 128536,\n", + " '': 128537,\n", + " '': 128538,\n", + " '': 128539,\n", + " '': 128540,\n", + " '': 128541,\n", + " '': 128542,\n", + " '': 128543,\n", + " '': 128544,\n", + " '': 128545,\n", + " '': 128546,\n", + " '': 128547,\n", + " '': 128548,\n", + " '': 128549,\n", + " '': 128550,\n", + " '': 128551,\n", + " '': 128552,\n", + " '': 128553,\n", + " '': 128554,\n", + " '': 128555,\n", + " '': 128556,\n", + " '': 128557,\n", + " '': 128558,\n", + " '': 128559,\n", + " '': 128560,\n", + " '': 128561,\n", + " '': 128562,\n", + " '': 128563,\n", + " '': 128564,\n", + " '': 128565,\n", + " '': 128566,\n", + " '': 128567,\n", + " '': 128568,\n", + " '': 128569,\n", + " '': 128570,\n", + " '': 128571,\n", + " '': 128572,\n", + " '': 128573,\n", + " '': 128574,\n", + " '': 128575,\n", + " '': 128576,\n", + " '': 128577,\n", + " '': 128578,\n", + " '': 128579,\n", + " '': 128580,\n", + " '': 128581,\n", + " '': 128582,\n", + " '': 128583,\n", + " '': 128584,\n", + " '': 128585,\n", + " '': 128586,\n", + " '': 128587,\n", + " '': 128588,\n", + " '': 128589,\n", + " '': 128590,\n", + " '': 128591,\n", + " '': 128592,\n", + " '': 128593,\n", + " '': 128594,\n", + " '': 128595,\n", + " '': 128596,\n", + " '': 128597,\n", + " '': 128598,\n", + " '': 128599,\n", + " '': 128600,\n", + " '': 128601,\n", + " '': 128602,\n", + " '': 128603,\n", + " '': 128604,\n", + " '': 128605,\n", + " '': 128606,\n", + " '': 128607,\n", + " '': 128608,\n", + " '': 128609,\n", + " '': 128610,\n", + " '': 128611,\n", + " '': 128612,\n", + " '': 128613,\n", + " '': 128614,\n", + " '': 128615,\n", + " '': 128616,\n", + " '': 128617,\n", + " '': 128618,\n", + " '': 128619,\n", + " '': 128620,\n", + " '': 128621,\n", + " '': 128622,\n", + " '': 128623,\n", + " '': 128624,\n", + " '': 128625,\n", + " '': 128626,\n", + " '': 128627,\n", + " '': 128628,\n", + " '': 128629,\n", + " '': 128630,\n", + " '': 128631,\n", + " '': 128632,\n", + " '': 128633,\n", + " '': 128634,\n", + " '': 128635,\n", + " '': 128636,\n", + " '': 128637,\n", + " '': 128638,\n", + " '': 128639,\n", + " '': 128640,\n", + " '': 128641,\n", + " '': 128642,\n", + " '': 128643,\n", + " '': 128644,\n", + " '': 128645,\n", + " '': 128646,\n", + " '': 128647,\n", + " '': 128648,\n", + " '': 128649,\n", + " '': 128650,\n", + " '': 128651,\n", + " '': 128652,\n", + " '': 128653,\n", + " '': 128654,\n", + " '': 128655,\n", + " '': 128656,\n", + " '': 128657,\n", + " '': 128658,\n", + " '': 128659,\n", + " '': 128660,\n", + " '': 128661,\n", + " '': 128662,\n", + " '': 128663,\n", + " '': 128664,\n", + " '': 128665,\n", + " '': 128666,\n", + " '': 128667,\n", + " '': 128668,\n", + " '': 128669,\n", + " '': 128670,\n", + " '': 128671,\n", + " '': 128672,\n", + " '': 128673,\n", + " '': 128674,\n", + " '': 128675,\n", + " '': 128676,\n", + " '': 128677,\n", + " '': 128678,\n", + " '': 128679,\n", + " '': 128680,\n", + " '': 128681,\n", + " '': 128682,\n", + " '': 128683,\n", + " '': 128684,\n", + " '': 128685,\n", + " '': 128686,\n", + " '': 128687,\n", + " '': 128688,\n", + " '': 128689,\n", + " '': 128690,\n", + " '': 128691,\n", + " '': 128692,\n", + " '': 128693,\n", + " '': 128694,\n", + " '': 128695,\n", + " '': 128696,\n", + " '': 128697,\n", + " '': 128698,\n", + " '': 128699,\n", + " '': 128700,\n", + " '': 128701,\n", + " '': 128702,\n", + " '': 128703,\n", + " '': 128704,\n", + " '': 128705,\n", + " '': 128706,\n", + " '': 128707,\n", + " '': 128708,\n", + " '': 128709,\n", + " '': 128710,\n", + " '': 128711,\n", + " '': 128712,\n", + " '': 128713,\n", + " '': 128714,\n", + " '': 128715,\n", + " '': 128716,\n", + " '': 128717,\n", + " '': 128718,\n", + " '': 128719,\n", + " '': 128720,\n", + " '': 128721,\n", + " '': 128722,\n", + " '': 128723,\n", + " '': 128724,\n", + " '': 128725,\n", + " '': 128726,\n", + " '': 128727,\n", + " '': 128728,\n", + " '': 128729,\n", + " '': 128730,\n", + " '': 128731,\n", + " '': 128732,\n", + " '': 128733,\n", + " '': 128734,\n", + " '': 128735,\n", + " '': 128736,\n", + " '': 128737,\n", + " '': 128738,\n", + " '': 128739,\n", + " '': 128740,\n", + " '': 128741,\n", + " '': 128742,\n", + " '': 128743,\n", + " '': 128744,\n", + " '': 128745,\n", + " '': 128746,\n", + " '': 128747,\n", + " '': 128748,\n", + " '': 128749,\n", + " '': 128750,\n", + " '': 128751,\n", + " '': 128752,\n", + " '': 128753,\n", + " '': 128754,\n", + " '': 128755,\n", + " '': 128756,\n", + " '': 128757,\n", + " '': 128758,\n", + " '': 128759,\n", + " '': 128760,\n", + " '': 128761,\n", + " '': 128762,\n", + " '': 128763,\n", + " '': 128764,\n", + " '': 128765,\n", + " '': 128766,\n", + " '': 128767,\n", + " '': 128768,\n", + " '': 128769,\n", + " '': 128770,\n", + " '': 128771,\n", + " '': 128772,\n", + " '': 128773,\n", + " '': 128774,\n", + " '': 128775,\n", + " '': 128776,\n", + " '': 128777,\n", + " '': 128778,\n", + " '': 128779,\n", + " '': 128780,\n", + " '': 128781,\n", + " '': 128782,\n", + " '': 128783,\n", + " '': 128784,\n", + " '': 128785,\n", + " '': 128786,\n", + " '': 128787,\n", + " '': 128788,\n", + " '': 128789,\n", + " '': 128790,\n", + " '': 128791,\n", + " '': 128792,\n", + " '': 128793,\n", + " '': 128794,\n", + " '': 128795,\n", + " '': 128796,\n", + " '': 128797,\n", + " '': 128798,\n", + " '': 128799,\n", + " '': 128800,\n", + " '': 128801,\n", + " '': 128802,\n", + " '': 128803,\n", + " '': 128804,\n", + " '': 128805,\n", + " '': 128806,\n", + " '': 128807,\n", + " '': 128808,\n", + " '': 128809,\n", + " '': 128810,\n", + " '': 128811,\n", + " '': 128812,\n", + " '': 128813,\n", + " '': 128814,\n", + " '': 128815,\n", + " '': 128816,\n", + " '': 128817,\n", + " '': 128818,\n", + " '': 128819,\n", + " '': 128820,\n", + " '': 128821,\n", + " '': 128822,\n", + " '': 128823,\n", + " '': 128824,\n", + " '': 128825,\n", + " '': 128826,\n", + " '': 128827,\n", + " '': 128828,\n", + " '': 128829,\n", + " '': 128830,\n", + " '': 128831,\n", + " '': 128832,\n", + " '': 128833,\n", + " '': 128834,\n", + " '': 128835,\n", + " '': 128836,\n", + " '': 128837,\n", + " '': 128838,\n", + " '': 128839,\n", + " '': 128840,\n", + " '': 128841,\n", + " '': 128842,\n", + " '': 128843,\n", + " '': 128844,\n", + " '': 128845,\n", + " '': 128846,\n", + " '': 128847,\n", + " '': 128848,\n", + " '': 128849,\n", + " '': 128850,\n", + " '': 128851,\n", + " '': 128852,\n", + " '': 128853,\n", + " '': 128854,\n", + " '': 128855,\n", + " '': 128856,\n", + " '': 128857,\n", + " '': 128858,\n", + " '': 128859,\n", + " '': 128860,\n", + " '': 128861,\n", + " '': 128862,\n", + " '': 128863,\n", + " '': 128864,\n", + " '': 128865,\n", + " '': 128866,\n", + " '': 128867,\n", + " '': 128868,\n", + " '': 128869,\n", + " '': 128870,\n", + " '': 128871,\n", + " '': 128872,\n", + " '': 128873,\n", + " '': 128874,\n", + " '': 128875,\n", + " '': 128876,\n", + " '': 128877,\n", + " '': 128878,\n", + " '': 128879,\n", + " '': 128880,\n", + " '': 128881,\n", + " '': 128882,\n", + " '': 128883,\n", + " '': 128884,\n", + " '': 128885,\n", + " '': 128886,\n", + " '': 128887,\n", + " '': 128888,\n", + " '': 128889,\n", + " '': 128890,\n", + " '': 128891,\n", + " '': 128892,\n", + " '': 128893,\n", + " '': 128894,\n", + " '': 128895,\n", + " '': 128896,\n", + " '': 128897,\n", + " '': 128898,\n", + " '': 128899,\n", + " '': 128900,\n", + " '': 128901,\n", + " '': 128902,\n", + " '': 128903,\n", + " '': 128904,\n", + " '': 128905,\n", + " '': 128906,\n", + " '': 128907,\n", + " '': 128908,\n", + " '': 128909,\n", + " '': 128910,\n", + " '': 128911,\n", + " '': 128912,\n", + " '': 128913,\n", + " '': 128914,\n", + " '': 128915,\n", + " '': 128916,\n", + " '': 128917,\n", + " '': 128918,\n", + " '': 128919,\n", + " '': 128920,\n", + " '': 128921,\n", + " '': 128922,\n", + " '': 128923,\n", + " '': 128924,\n", + " '': 128925,\n", + " '': 128926,\n", + " '': 128927,\n", + " '': 128928,\n", + " '': 128929,\n", + " '': 128930,\n", + " '': 128931,\n", + " '': 128932,\n", + " '': 128933,\n", + " '': 128934,\n", + " '': 128935,\n", + " '': 128936,\n", + " '': 128937,\n", + " '': 128938,\n", + " '': 128939,\n", + " '': 128940,\n", + " '': 128941,\n", + " '': 128942,\n", + " '': 128943,\n", + " '': 128944,\n", + " '': 128945,\n", + " '': 128946,\n", + " '': 128947,\n", + " '': 128948,\n", + " '': 128949,\n", + " '': 128950,\n", + " '': 128951,\n", + " '': 128952,\n", + " '': 128953,\n", + " '': 128954,\n", + " '': 128955,\n", + " '': 128956,\n", + " '': 128957,\n", + " '': 128958,\n", + " '': 128959,\n", + " '': 128960,\n", + " '': 128961,\n", + " '': 128962,\n", + " '': 128963,\n", + " '': 128964,\n", + " '': 128965,\n", + " '': 128966,\n", + " '': 128967,\n", + " '': 128968,\n", + " '': 128969,\n", + " '': 128970,\n", + " '': 128971,\n", + " '': 128972,\n", + " '': 128973,\n", + " '': 128974,\n", + " '': 128975,\n", + " '': 128976,\n", + " '': 128977,\n", + " '': 128978,\n", + " '': 128979,\n", + " '': 128980,\n", + " '': 128981,\n", + " '': 128982,\n", + " '': 128983,\n", + " '': 128984,\n", + " '': 128985,\n", + " '': 128986,\n", + " '': 128987,\n", + " '': 128988,\n", + " '': 128989,\n", + " '': 128990,\n", + " '': 128991,\n", + " '': 128992,\n", + " '': 128993,\n", + " '': 128994,\n", + " '': 128995,\n", + " '': 128996,\n", + " '': 128997,\n", + " '': 128998,\n", + " '': 128999,\n", + " ...}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llama_tokenizer.add_tokens(\"<|img_start|>\",special_tokens=True)\n", + "llama_tokenizer.add_tokens(\"<|img_end|>\",special_tokens=True)\n", + "for img_token in range(0, 8192):\n", + " padded_token = f\"\" # This pads the img_token with zeros to ensure it is 4 digits long.\n", + " llama_tokenizer.add_tokens(padded_token)\n", + "llama_tokenizer.get_added_vocab()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3ca6861c-4182-4d23-984f-92e0b5ba22f2", + "metadata": {}, + "outputs": [ + { + "ename": "AssertionError", + "evalue": "Key <|img_start|> is not a special token", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mllama_tokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_special_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m<|img_start|>\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43masdf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:947\u001b[0m, in \u001b[0;36mSpecialTokensMixin.add_special_tokens\u001b[0;34m(self, special_tokens_dict, replace_additional_special_tokens)\u001b[0m\n\u001b[1;32m 945\u001b[0m added_tokens \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m special_tokens_dict\u001b[38;5;241m.\u001b[39mitems():\n\u001b[0;32m--> 947\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mSPECIAL_TOKENS_ATTRIBUTES, \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mKey \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is not a special token\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 949\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose:\n\u001b[1;32m 950\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAssigning \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvalue\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to the \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m key of the tokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mAssertionError\u001b[0m: Key <|img_start|> is not a special token" + ] + } + ], + "source": [ + "llama_tokenizer.regist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aaf0a3d1-dd5d-402b-9802-b692921c9079", + "metadata": {}, + "outputs": [], + "source": [ + "llama_tokenizer.add_special_tokens({\"image start\":\"<|img_start|>\"})\n", + "llama_tokenizer.add_special_tokens(\"<|img_end|>\")" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "a3e9b926-2cd4-4870-b219-7f92cd9295c4", + "metadata": {}, + "outputs": [], + "source": [ + "# We need to pad the beginning position of vqvae since discrete token is range(0,8192)\n", + "pad_idx_vqvae = llama_tokenizer.vocab['']" + ] + }, + { + "cell_type": "markdown", + "id": "17f19396-b4f2-4764-a0ed-fa0bcb98ca43", + "metadata": {}, + "source": [ + "## Process some sample dataset from LAION\n", + "\n", + "Our target here is to generate a simple text only dataset that is suitable to be used by traditional tokenizers, but also suitable for representing the images" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "7998066b-6fe7-4ca8-8ede-3870f37f3725", + "metadata": {}, + "outputs": [], + "source": [ + "import albumentations as A\n", + "from albumentations.pytorch import ToTensorV2\n", + "import torch\n", + "import webdataset as wds\n", + "\n", + "def process_data(data):\n", + " pretransforms = A.Compose([\n", + " A.SmallestMaxSize(512),\n", + " A.CenterCrop(512, 512, always_apply=True),\n", + " ToTensorV2()\n", + " ])\n", + " data[\"jpg\"] = pretransforms(image=data[\"jpg\"])[\"image\"]\n", + " # Convert image to bfloat16\n", + " data[\"jpg\"] = data[\"jpg\"].to(torch.bfloat16)\n", + " return data\n", + "\n", + "url = \"file:make_a_scene/00000.tar\"\n", + "dataset = wds.WebDataset(url).decode(\"rgb\").map(process_data).to_tuple(\"jpg\", \"txt\")\n", + "\n", + "def collate(batch):\n", + " images = torch.stack([i[0] for i in batch], dim=0)\n", + " captions = [i[1] for i in batch]\n", + " return [images, captions]\n", + "\n", + "loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=collate)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "8919fa2e-4e07-48d9-876f-c033a0fd1ab8", + "metadata": {}, + "outputs": [], + "source": [ + "image_text_dataset = [] " + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "c97bf893-aeda-43e4-b703-bdb425dcda3e", + "metadata": {}, + "outputs": [], + "source": [ + "counter = 0\n", + "for image,description in dataset:\n", + " counter += 1\n", + " if counter > 10:\n", + " break\n", + " tokenized_image_text = []\n", + "\n", + " # Get the tokens for the image part, do not forget to pad the position of codebook\n", + " discrete_tokens_padded = list(vq_vae.encode(image.to(device).unsqueeze(0))[2] + pad_idx_vqvae)\n", + "\n", + " # Get the tokens of the image description\n", + " describe_text = f\"\"\"DESCRIPTION:\n", + " {description}\n", + " IMAGE:\n", + " \"\"\"\n", + "\n", + " describe_text_tokens = llama_tokenizer.encode(describe_text)\n", + "\n", + " pos_img_start = llama_tokenizer.vocab['<|img_start|>']\n", + " pos_img_end = llama_tokenizer.vocab['<|img_end|>']\n", + " # Combine the tokens of image and text\n", + " tokenized_image_text = describe_text_tokens + [pos_img_start] + discrete_tokens_padded + [pos_img_end]\n", + "\n", + " # Reconstruct the text\n", + " recontructed_text = llama_tokenizer.decode(tokenized_image_text) \n", + " image_text_dataset.append(recontructed_text)" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "320c91a9-e0c0-4786-8eaa-424fa5e8e41e", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<|begin_of_text|>DESCRIPTION:\n", + " No Choc Easter Gifts for Babies First Easter Shoes\n", + " IMAGE:\n", + " <|img_start|><|img_end|>\n" + ] + } + ], + "source": [ + "print(image_text_dataset[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "85037f9a-91aa-4991-ae41-251e926343b9", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['<|begin_of_text|>',\n", + " 'DESCRIPTION',\n", + " ':Ċ',\n", + " 'ĠĠĠ',\n", + " 'ĠNo',\n", + " 'ĠCh',\n", + " 'oc',\n", + " 'ĠEaster',\n", + " 'ĠGifts',\n", + " 'Ġfor',\n", + " 'ĠBabies',\n", + " 'ĠFirst',\n", + " 'ĠEaster',\n", + " 'ĠShoes',\n", + " 'Ċ',\n", + " 'ĠĠĠ',\n", + " 'ĠIMAGE',\n", + " ':Ċ',\n", + " 'ĠĠĠĠ',\n", + " '<|img_start|>',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " ...]" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llama_tokenizer.tokenize(image_text_dataset[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4495b402-95be-48c2-8829-f240ac883969", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}