diff --git "a/demo.ipynb" "b/demo.ipynb" --- "a/demo.ipynb" +++ "b/demo.ipynb" @@ -4,17 +4,9 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/Weiyao.Wang/virtualenvs/Kanzo/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ + "import IPython.display as display\n", "from src.aki import AKI\n", "from transformers import AutoTokenizer, AutoConfig\n", "from torchvision.transforms import Compose, Resize, Lambda, ToTensor, Normalize\n", @@ -23,7 +15,12 @@ " from torchvision.transforms import InterpolationMode\n", " BICUBIC = InterpolationMode.BICUBIC\n", "except ImportError:\n", - " BICUBIC = Image.BICUBIC" + " BICUBIC = Image.BICUBIC\n", + "\n", + "# replace GenerationMixin to modify attention mask handling\n", + "from transformers.generation.utils import GenerationMixin\n", + "from src.aki_generation import _aki_update_model_kwargs_for_generation\n", + "GenerationMixin._update_model_kwargs_for_generation = _aki_update_model_kwargs_for_generation" ] }, { @@ -49,11 +46,6 @@ " n_px = getattr(config, \"n_px\", 384)\n", " norm_mean = getattr(config, \"norm_mean\", 0.5)\n", " norm_std = getattr(config, \"norm_std\", 0.5)\n", - "\n", - " # replace GenerationMixin to modify attention mask handling\n", - " from transformers.generation.utils import GenerationMixin\n", - " from open_flamingo import _aki_update_model_kwargs_for_generation\n", - " GenerationMixin._update_model_kwargs_for_generation = _aki_update_model_kwargs_for_generation\n", " \n", " tokenizer = AutoTokenizer.from_pretrained(ckpt_path)\n", " model = AKI.from_pretrained(ckpt_path, tokenizer=tokenizer)\n", @@ -71,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -81,10 +73,23 @@ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.\n", - "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.\n", - "Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00, 1.52s/it]\n" + "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.\n" ] }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "64e0aee907ed4b29b238f38b74762f95", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stdout", "output_type": "stream", @@ -164,13 +182,7 @@ "\n", "Describe the scene of this image.<|end|>\n", "<|assistant|>\n", - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "\n", "Response:\n", " The image captures a beautiful autumn day in a park, with a pathway covered in a vibrant carpet of fallen leaves. The leaves are in various shades of red, orange, yellow, and brown, creating a warm and colorful atmosphere. The path is lined with trees displaying beautiful autumn foliage, adding to the picturesque setting.\n", "\n", @@ -184,13 +196,6 @@ "response = process_input(image_path, text_input)\n", "print(\"Response:\\n\", response)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {