Vision-CAIR committed on
Commit 879e56f
1 Parent(s): 2ef30bd

Delete develop.ipynb

Files changed (1)
  1. develop.ipynb +0 -929
develop.ipynb DELETED
@@ -1,929 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "d5ac353e",
7
- "metadata": {
8
- "pycharm": {
9
- "name": "#%%\n"
10
- }
11
- },
12
- "outputs": [],
13
- "source": [
14
- "import argparse\n",
15
- "import os\n",
16
- "import shutil\n",
17
- "import random\n",
18
- "from PIL import Image\n",
19
- "\n",
20
- "import numpy as np\n",
21
- "import torch\n",
22
- "import torch.backends.cudnn as cudnn\n",
23
- "from transformers import StoppingCriteria, StoppingCriteriaList\n",
24
- "\n",
25
- "import lavis.tasks as tasks\n",
26
- "from lavis.common.config import Config\n",
27
- "from lavis.common.dist_utils import get_rank, init_distributed_mode\n",
28
- "from lavis.common.logger import setup_logger\n",
29
- "from lavis.common.optims import (\n",
30
- " LinearWarmupCosineLRScheduler,\n",
31
- " LinearWarmupStepLRScheduler,\n",
32
- ")\n",
33
- "from lavis.common.registry import registry\n",
34
- "from lavis.common.utils import now\n",
35
- "\n",
36
- "# imports modules for registration\n",
37
- "from lavis.datasets.builders import *\n",
38
- "from lavis.models import *\n",
39
- "from lavis.processors import *\n",
40
- "from lavis.runners import *\n",
41
- "from lavis.tasks import *"
42
- ]
43
- },
44
- {
45
- "cell_type": "code",
46
- "execution_count": null,
47
- "id": "4fdef7a6",
48
- "metadata": {
49
- "pycharm": {
50
- "name": "#%%\n"
51
- }
52
- },
53
- "outputs": [],
54
- "source": [
55
- "shutil.copytree('/ibex/project/c2133/vicuna', '/tmp/vicuna')"
56
- ]
57
- },
58
- {
59
- "cell_type": "code",
60
- "execution_count": 2,
61
- "id": "661f9e80",
62
- "metadata": {
63
- "pycharm": {
64
- "name": "#%%\n"
65
- }
66
- },
67
- "outputs": [],
68
- "source": [
69
- "class StoppingCriteriaSub(StoppingCriteria):\n",
70
- "\n",
71
- " def __init__(self, stops = [], encounters=1):\n",
72
- " super().__init__()\n",
73
- " self.stops = stops\n",
74
- "\n",
75
- " def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):\n",
76
- " for stop in self.stops:\n",
77
- " if torch.all((stop == input_ids[0][-len(stop):])).item():\n",
78
- " return True\n",
79
- "\n",
80
- " return False\n",
81
- "\n",
82
- "\n",
83
- "stop_words_ids = [torch.tensor([835]).to('cuda:0'), \n",
84
- " torch.tensor([2277, 29937]).to('cuda:0')] # '###' can be encoded in different ways.\n",
85
- "stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])"
86
- ]
87
- },
88
- {
89
- "cell_type": "code",
90
- "execution_count": 6,
91
- "id": "1822a77a",
92
- "metadata": {
93
- "pycharm": {
94
- "name": "#%%\n"
95
- }
96
- },
97
- "outputs": [],
98
- "source": [
99
- "parser = argparse.ArgumentParser(description=\"Training\")\n",
100
- "\n",
101
- "parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n",
102
- "parser.add_argument(\n",
103
- " \"--options\",\n",
104
- " nargs=\"+\",\n",
105
- " help=\"override some settings in the used config, the key-value pair \"\n",
106
- " \"in xxx=yyy format will be merged into config file (deprecate), \"\n",
107
- " \"change to --cfg-options instead.\",\n",
108
- ")\n",
109
- "\n",
110
- "args = parser.parse_args([\"--cfg-path\", \"lavis/projects/blip2/train/vicuna_pretrain_stage2_cc.yaml\"])\n",
111
- "\n",
112
- "cfg = Config(args)\n",
113
- "device = 'cuda:0'"
114
- ]
115
- },
116
- {
117
- "cell_type": "code",
118
- "execution_count": 4,
119
- "id": "57e90f19",
120
- "metadata": {
121
- "pycharm": {
122
- "name": "#%%\n"
123
- }
124
- },
125
- "outputs": [],
126
- "source": [
127
- "vis_processor_cfg = cfg.datasets_cfg.cc_combine.vis_processor.train\n",
128
- "vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)"
129
- ]
130
- },
131
- {
132
- "cell_type": "code",
133
- "execution_count": 7,
134
- "id": "4cc521da",
135
- "metadata": {
136
- "pycharm": {
137
- "name": "#%%\n"
138
- }
139
- },
140
- "outputs": [
141
- {
142
- "name": "stdout",
143
- "output_type": "stream",
144
- "text": [
145
- "Loading LLAMA\n"
146
- ]
147
- },
148
- {
149
- "data": {
150
- "application/vnd.jupyter.widget-view+json": {
151
- "model_id": "abeac6970d914446adc1fb73f7e5b5f9",
152
- "version_major": 2,
153
- "version_minor": 0
154
- },
155
- "text/plain": [
156
- "Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
157
- ]
158
- },
159
- "metadata": {},
160
- "output_type": "display_data"
161
- },
162
- {
163
- "name": "stdout",
164
- "output_type": "stream",
165
- "text": [
166
- "Loading LLAMA Done\n",
167
- "Load BLIP2-LLM Checkpoint: /home/zhud/project/blip2/lavis/output/BLIP2/Vicuna_pretrain_stage2_cc/20230405233/checkpoint_3.pth\n"
168
- ]
169
- },
170
- {
171
- "data": {
172
- "text/html": [
173
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000\">╭─────────────────────────────── </span><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Traceback </span><span style=\"color: #bf7f7f; text-decoration-color: #bf7f7f; font-weight: bold\">(most recent call last)</span><span style=\"color: #800000; text-decoration-color: #800000\"> ────────────────────────────────╮</span>\n",
174
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">&lt;module&gt;</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">2</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
175
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
176
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1 </span>task = tasks.setup_task(cfg) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
177
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>2 model = task.build_model(cfg) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
178
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">3 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
179
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
180
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/home/zhud/project/blip2/lavis/tasks/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">base_task.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">33</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">build_model</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
181
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
182
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 30 │ │ </span>model_config = cfg.model_cfg <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
183
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 31 │ │ </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
184
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 32 │ │ </span>model_cls = registry.get_model_class(model_config.arch) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
185
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span> 33 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> model_cls.from_config(model_config) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
186
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 34 │ </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
187
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 35 │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">def</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">build_datasets</span>(<span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>, cfg): <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
188
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 36 </span><span style=\"color: #bfbfbf; text-decoration-color: #bfbfbf\">│ │ </span><span style=\"color: #808000; text-decoration-color: #808000\">\"\"\"</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
189
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
190
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/home/zhud/project/blip2/lavis/models/blip2_models/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">blip2_llama.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">315</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">from_config</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
191
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
192
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">312 │ │ </span>ckpt_path = cfg.get(<span style=\"color: #808000; text-decoration-color: #808000\">\"ckpt\"</span>, <span style=\"color: #808000; text-decoration-color: #808000\">\"\"</span>) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
193
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">313 │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> ckpt_path: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
194
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">314 │ │ │ </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">print</span>(<span style=\"color: #808000; text-decoration-color: #808000\">\"Load BLIP2-LLM Checkpoint: {}\"</span>.format(ckpt_path)) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
195
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>315 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ </span>ckpt = torch.load(ckpt_path, map_location=<span style=\"color: #808000; text-decoration-color: #808000\">\"cpu\"</span>) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
196
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">316 │ │ │ </span>msg = model.load_state_dict(ckpt[<span style=\"color: #808000; text-decoration-color: #808000\">'model'</span>], strict=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">False</span>) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
197
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">317 │ │ </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
198
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">318 │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> model <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
199
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
200
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">serialization.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">791</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">load</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
201
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
202
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 788 │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> <span style=\"color: #808000; text-decoration-color: #808000\">'encoding'</span> <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">not</span> <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> pickle_load_args.keys(): <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
203
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 789 │ │ </span>pickle_load_args[<span style=\"color: #808000; text-decoration-color: #808000\">'encoding'</span>] = <span style=\"color: #808000; text-decoration-color: #808000\">'utf-8'</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
204
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 790 │ </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
205
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span> 791 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">with</span> _open_file_like(f, <span style=\"color: #808000; text-decoration-color: #808000\">'rb'</span>) <span style=\"color: #0000ff; text-decoration-color: #0000ff\">as</span> opened_file: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
206
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 792 │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> _is_zipfile(opened_file): <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
207
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 793 │ │ │ # The zipfile reader is going to advance the current file position.</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
208
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 794 │ │ │ # If we want to actually tail call to torch.jit.load, we need to</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
209
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
210
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">serialization.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">271</span> in <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
211
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">_open_file_like</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
212
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
213
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 268 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
214
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 269 </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">def</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">_open_file_like</span>(name_or_buffer, mode): <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
215
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 270 │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> _is_path(name_or_buffer): <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
216
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span> 271 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> _open_file(name_or_buffer, mode) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
217
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 272 │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">else</span>: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
218
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 273 │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> <span style=\"color: #808000; text-decoration-color: #808000\">'w'</span> <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> mode: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
219
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 274 │ │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> _open_buffer_writer(name_or_buffer) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
220
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
221
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">serialization.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">252</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">__init__</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
222
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
223
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 249 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
224
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 250 </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">class</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; text-decoration: underline\">_open_file</span>(_opener): <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
225
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 251 │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">def</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">__init__</span>(<span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>, name, mode): <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
226
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span> 252 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">super</span>().<span style=\"color: #00ff00; text-decoration-color: #00ff00\">__init__</span>(<span style=\"color: #00ffff; text-decoration-color: #00ffff\">open</span>(name, mode)) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
227
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 253 │ </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
228
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 254 │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">def</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">__exit__</span>(<span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>, *args): <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
229
- "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 255 │ │ </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>.file_like.close() <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
230
- "<span style=\"color: #800000; text-decoration-color: #800000\">╰──────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n",
231
- "<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">FileNotFoundError: </span><span style=\"font-weight: bold\">[</span>Errno <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span><span style=\"font-weight: bold\">]</span> No such file or directory: \n",
232
- "<span style=\"color: #008000; text-decoration-color: #008000\">'/home/zhud/project/blip2/lavis/output/BLIP2/Vicuna_pretrain_stage2_cc/20230405233/checkpoint_3.pth'</span>\n",
233
- "</pre>\n"
234
- ],
235
- "text/plain": [
236
- "\u001B[31m╭─\u001B[0m\u001B[31m──────────────────────────────\u001B[0m\u001B[31m \u001B[0m\u001B[1;31mTraceback \u001B[0m\u001B[1;2;31m(most recent call last)\u001B[0m\u001B[31m \u001B[0m\u001B[31m───────────────────────────────\u001B[0m\u001B[31m─╮\u001B[0m\n",
237
- "\u001B[31m│\u001B[0m in \u001B[92m<module>\u001B[0m:\u001B[94m2\u001B[0m \u001B[31m│\u001B[0m\n",
238
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
239
- "\u001B[31m│\u001B[0m \u001B[2m1 \u001B[0mtask = tasks.setup_task(cfg) \u001B[31m│\u001B[0m\n",
240
- "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m2 model = task.build_model(cfg) \u001B[31m│\u001B[0m\n",
241
- "\u001B[31m│\u001B[0m \u001B[2m3 \u001B[0m \u001B[31m│\u001B[0m\n",
242
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
243
- "\u001B[31m│\u001B[0m \u001B[2;33m/home/zhud/project/blip2/lavis/tasks/\u001B[0m\u001B[1;33mbase_task.py\u001B[0m:\u001B[94m33\u001B[0m in \u001B[92mbuild_model\u001B[0m \u001B[31m│\u001B[0m\n",
244
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
245
- "\u001B[31m│\u001B[0m \u001B[2m 30 \u001B[0m\u001B[2m│ │ \u001B[0mmodel_config = cfg.model_cfg \u001B[31m│\u001B[0m\n",
246
- "\u001B[31m│\u001B[0m \u001B[2m 31 \u001B[0m\u001B[2m│ │ \u001B[0m \u001B[31m│\u001B[0m\n",
247
- "\u001B[31m│\u001B[0m \u001B[2m 32 \u001B[0m\u001B[2m│ │ \u001B[0mmodel_cls = registry.get_model_class(model_config.arch) \u001B[31m│\u001B[0m\n",
248
- "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m 33 \u001B[2m│ │ \u001B[0m\u001B[94mreturn\u001B[0m model_cls.from_config(model_config) \u001B[31m│\u001B[0m\n",
249
- "\u001B[31m│\u001B[0m \u001B[2m 34 \u001B[0m\u001B[2m│ \u001B[0m \u001B[31m│\u001B[0m\n",
250
- "\u001B[31m│\u001B[0m \u001B[2m 35 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94mdef\u001B[0m \u001B[92mbuild_datasets\u001B[0m(\u001B[96mself\u001B[0m, cfg): \u001B[31m│\u001B[0m\n",
251
- "\u001B[31m│\u001B[0m \u001B[2m 36 \u001B[0m\u001B[2;90m│ │ \u001B[0m\u001B[33m\"\"\"\u001B[0m \u001B[31m│\u001B[0m\n",
252
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
253
- "\u001B[31m│\u001B[0m \u001B[2;33m/home/zhud/project/blip2/lavis/models/blip2_models/\u001B[0m\u001B[1;33mblip2_llama.py\u001B[0m:\u001B[94m315\u001B[0m in \u001B[92mfrom_config\u001B[0m \u001B[31m│\u001B[0m\n",
254
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
255
- "\u001B[31m│\u001B[0m \u001B[2m312 \u001B[0m\u001B[2m│ │ \u001B[0mckpt_path = cfg.get(\u001B[33m\"\u001B[0m\u001B[33mckpt\u001B[0m\u001B[33m\"\u001B[0m, \u001B[33m\"\u001B[0m\u001B[33m\"\u001B[0m) \u001B[31m│\u001B[0m\n",
256
- "\u001B[31m│\u001B[0m \u001B[2m313 \u001B[0m\u001B[2m│ │ \u001B[0m\u001B[94mif\u001B[0m ckpt_path: \u001B[31m│\u001B[0m\n",
257
- "\u001B[31m│\u001B[0m \u001B[2m314 \u001B[0m\u001B[2m│ │ │ \u001B[0m\u001B[96mprint\u001B[0m(\u001B[33m\"\u001B[0m\u001B[33mLoad BLIP2-LLM Checkpoint: \u001B[0m\u001B[33m{}\u001B[0m\u001B[33m\"\u001B[0m.format(ckpt_path)) \u001B[31m│\u001B[0m\n",
258
- "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m315 \u001B[2m│ │ │ \u001B[0mckpt = torch.load(ckpt_path, map_location=\u001B[33m\"\u001B[0m\u001B[33mcpu\u001B[0m\u001B[33m\"\u001B[0m) \u001B[31m│\u001B[0m\n",
259
- "\u001B[31m│\u001B[0m \u001B[2m316 \u001B[0m\u001B[2m│ │ │ \u001B[0mmsg = model.load_state_dict(ckpt[\u001B[33m'\u001B[0m\u001B[33mmodel\u001B[0m\u001B[33m'\u001B[0m], strict=\u001B[94mFalse\u001B[0m) \u001B[31m│\u001B[0m\n",
260
- "\u001B[31m│\u001B[0m \u001B[2m317 \u001B[0m\u001B[2m│ │ \u001B[0m \u001B[31m│\u001B[0m\n",
261
- "\u001B[31m│\u001B[0m \u001B[2m318 \u001B[0m\u001B[2m│ │ \u001B[0m\u001B[94mreturn\u001B[0m model \u001B[31m│\u001B[0m\n",
262
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
263
- "\u001B[31m│\u001B[0m \u001B[2;33m/home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/\u001B[0m\u001B[1;33mserialization.py\u001B[0m:\u001B[94m791\u001B[0m in \u001B[92mload\u001B[0m \u001B[31m│\u001B[0m\n",
264
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
265
- "\u001B[31m│\u001B[0m \u001B[2m 788 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94mif\u001B[0m \u001B[33m'\u001B[0m\u001B[33mencoding\u001B[0m\u001B[33m'\u001B[0m \u001B[95mnot\u001B[0m \u001B[95min\u001B[0m pickle_load_args.keys(): \u001B[31m│\u001B[0m\n",
266
- "\u001B[31m│\u001B[0m \u001B[2m 789 \u001B[0m\u001B[2m│ │ \u001B[0mpickle_load_args[\u001B[33m'\u001B[0m\u001B[33mencoding\u001B[0m\u001B[33m'\u001B[0m] = \u001B[33m'\u001B[0m\u001B[33mutf-8\u001B[0m\u001B[33m'\u001B[0m \u001B[31m│\u001B[0m\n",
267
- "\u001B[31m│\u001B[0m \u001B[2m 790 \u001B[0m\u001B[2m│ \u001B[0m \u001B[31m��\u001B[0m\n",
268
- "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m 791 \u001B[2m│ \u001B[0m\u001B[94mwith\u001B[0m _open_file_like(f, \u001B[33m'\u001B[0m\u001B[33mrb\u001B[0m\u001B[33m'\u001B[0m) \u001B[94mas\u001B[0m opened_file: \u001B[31m│\u001B[0m\n",
269
- "\u001B[31m│\u001B[0m \u001B[2m 792 \u001B[0m\u001B[2m│ │ \u001B[0m\u001B[94mif\u001B[0m _is_zipfile(opened_file): \u001B[31m│\u001B[0m\n",
270
- "\u001B[31m│\u001B[0m \u001B[2m 793 \u001B[0m\u001B[2m│ │ │ \u001B[0m\u001B[2m# The zipfile reader is going to advance the current file position.\u001B[0m \u001B[31m│\u001B[0m\n",
271
- "\u001B[31m│\u001B[0m \u001B[2m 794 \u001B[0m\u001B[2m│ │ │ \u001B[0m\u001B[2m# If we want to actually tail call to torch.jit.load, we need to\u001B[0m \u001B[31m│\u001B[0m\n",
272
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
273
- "\u001B[31m│\u001B[0m \u001B[2;33m/home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/\u001B[0m\u001B[1;33mserialization.py\u001B[0m:\u001B[94m271\u001B[0m in \u001B[31m│\u001B[0m\n",
274
- "\u001B[31m│\u001B[0m \u001B[92m_open_file_like\u001B[0m \u001B[31m│\u001B[0m\n",
275
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
276
- "\u001B[31m│\u001B[0m \u001B[2m 268 \u001B[0m \u001B[31m│\u001B[0m\n",
277
- "\u001B[31m│\u001B[0m \u001B[2m 269 \u001B[0m\u001B[94mdef\u001B[0m \u001B[92m_open_file_like\u001B[0m(name_or_buffer, mode): \u001B[31m│\u001B[0m\n",
278
- "\u001B[31m│\u001B[0m \u001B[2m 270 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94mif\u001B[0m _is_path(name_or_buffer): \u001B[31m│\u001B[0m\n",
279
- "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m 271 \u001B[2m│ │ \u001B[0m\u001B[94mreturn\u001B[0m _open_file(name_or_buffer, mode) \u001B[31m│\u001B[0m\n",
280
- "\u001B[31m│\u001B[0m \u001B[2m 272 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94melse\u001B[0m: \u001B[31m│\u001B[0m\n",
281
- "\u001B[31m│\u001B[0m \u001B[2m 273 \u001B[0m\u001B[2m│ │ \u001B[0m\u001B[94mif\u001B[0m \u001B[33m'\u001B[0m\u001B[33mw\u001B[0m\u001B[33m'\u001B[0m \u001B[95min\u001B[0m mode: \u001B[31m│\u001B[0m\n",
282
- "\u001B[31m│\u001B[0m \u001B[2m 274 \u001B[0m\u001B[2m│ │ │ \u001B[0m\u001B[94mreturn\u001B[0m _open_buffer_writer(name_or_buffer) \u001B[31m│\u001B[0m\n",
283
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
284
- "\u001B[31m│\u001B[0m \u001B[2;33m/home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/\u001B[0m\u001B[1;33mserialization.py\u001B[0m:\u001B[94m252\u001B[0m in \u001B[92m__init__\u001B[0m \u001B[31m│\u001B[0m\n",
285
- "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n",
286
- "\u001B[31m│\u001B[0m \u001B[2m 249 \u001B[0m \u001B[31m│\u001B[0m\n",
287
- "\u001B[31m│\u001B[0m \u001B[2m 250 \u001B[0m\u001B[94mclass\u001B[0m \u001B[4;92m_open_file\u001B[0m(_opener): \u001B[31m│\u001B[0m\n",
288
- "\u001B[31m│\u001B[0m \u001B[2m 251 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94mdef\u001B[0m \u001B[92m__init__\u001B[0m(\u001B[96mself\u001B[0m, name, mode): \u001B[31m│\u001B[0m\n",
289
- "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m 252 \u001B[2m│ │ \u001B[0m\u001B[96msuper\u001B[0m().\u001B[92m__init__\u001B[0m(\u001B[96mopen\u001B[0m(name, mode)) \u001B[31m│\u001B[0m\n",
290
- "\u001B[31m│\u001B[0m \u001B[2m 253 \u001B[0m\u001B[2m│ \u001B[0m \u001B[31m│\u001B[0m\n",
291
- "\u001B[31m│\u001B[0m \u001B[2m 254 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94mdef\u001B[0m \u001B[92m__exit__\u001B[0m(\u001B[96mself\u001B[0m, *args): \u001B[31m│\u001B[0m\n",
292
- "\u001B[31m│\u001B[0m \u001B[2m 255 \u001B[0m\u001B[2m│ │ \u001B[0m\u001B[96mself\u001B[0m.file_like.close() \u001B[31m│\u001B[0m\n",
293
- "\u001B[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001B[0m\n",
294
- "\u001B[1;91mFileNotFoundError: \u001B[0m\u001B[1m[\u001B[0mErrno \u001B[1;36m2\u001B[0m\u001B[1m]\u001B[0m No such file or directory: \n",
295
- "\u001B[32m'/home/zhud/project/blip2/lavis/output/BLIP2/Vicuna_pretrain_stage2_cc/20230405233/checkpoint_3.pth'\u001B[0m\n"
296
- ]
297
- },
298
- "metadata": {},
299
- "output_type": "display_data"
300
- }
301
- ],
302
- "source": [
303
- "task = tasks.setup_task(cfg)\n",
304
- "model = task.build_model(cfg)"
305
- ]
306
- },
307
- {
308
- "cell_type": "code",
309
- "execution_count": 9,
310
- "id": "ba874036",
311
- "metadata": {
312
- "pycharm": {
313
- "name": "#%%\n"
314
- }
315
- },
316
- "outputs": [
317
- {
318
- "data": {
319
- "text/plain": [
320
- "'/ibex/project/c2133/vicuna'"
321
- ]
322
- },
323
- "execution_count": 9,
324
- "metadata": {},
325
- "output_type": "execute_result"
326
- }
327
- ],
328
- "source": []
329
- },
330
- {
331
- "cell_type": "markdown",
332
- "id": "bf1c4e1c",
333
- "metadata": {
334
- "pycharm": {
335
- "name": "#%% md\n"
336
- }
337
- },
338
- "source": [
339
- "### Load Checkpoint"
340
- ]
341
- },
342
- {
343
- "cell_type": "code",
344
- "execution_count": null,
345
- "id": "a2a7f2bd",
346
- "metadata": {
347
- "pycharm": {
348
- "name": "#%%\n"
349
- }
350
- },
351
- "outputs": [],
352
- "source": [
353
- "ckpt_path = '/ibex/project/c2133/vicuna_ckpt_test/Vicuna_prompt_stage2_laion/20230410145/checkpoint_4.pth'\n",
354
- "ckpt = torch.load(ckpt_path, map_location=\"cpu\")\n",
355
- "msg = model.load_state_dict(ckpt['model'], strict=False)\n",
356
- "model = model.to(device)"
357
- ]
358
- },
359
- {
360
- "cell_type": "markdown",
361
- "id": "035a495f",
362
- "metadata": {
363
- "pycharm": {
364
- "name": "#%% md\n"
365
- }
366
- },
367
- "source": [
368
- "### Example of Tokenizer"
369
- ]
370
- },
371
- {
372
- "cell_type": "code",
373
- "execution_count": 35,
374
- "id": "3426ae10",
375
- "metadata": {
376
- "pycharm": {
377
- "name": "#%%\n"
378
- }
379
- },
380
- "outputs": [],
381
- "source": [
382
- "texts = [\"A chat\", \"The assistant gives helpful\"]\n",
383
- "\n",
384
- "llama_tokens = model.llama_tokenizer(\n",
385
- " texts, \n",
386
- " return_tensors=\"pt\", \n",
387
- " padding=\"longest\",\n",
388
- " truncation=True,\n",
389
- " max_length=10).to(device)"
390
- ]
391
- },
392
- {
393
- "cell_type": "code",
394
- "execution_count": 13,
395
- "id": "376400a4",
396
- "metadata": {
397
- "pycharm": {
398
- "name": "#%%\n"
399
- }
400
- },
401
- "outputs": [],
402
- "source": [
403
- "texts = \"The assistant gives helpful\"\n",
404
- "\n",
405
- "llama_tokens = model.llama_tokenizer(\n",
406
- " texts, \n",
407
- " return_tensors=\"pt\", \n",
408
- " padding=\"longest\",\n",
409
- " truncation=True,\n",
410
- " max_length=10).to(device)"
411
- ]
412
- },
413
- {
414
- "cell_type": "code",
415
- "execution_count": 14,
416
- "id": "6988ee66",
417
- "metadata": {
418
- "pycharm": {
419
- "name": "#%%\n"
420
- }
421
- },
422
- "outputs": [
423
- {
424
- "data": {
425
- "text/plain": [
426
- "torch.Size([1, 5])"
427
- ]
428
- },
429
- "execution_count": 14,
430
- "metadata": {},
431
- "output_type": "execute_result"
432
- }
433
- ],
434
- "source": [
435
- "llama_tokens.attention_mask.shape"
436
- ]
437
- },
438
- {
439
- "cell_type": "code",
440
- "execution_count": 9,
441
- "id": "dc9e376d",
442
- "metadata": {
443
- "pycharm": {
444
- "name": "#%%\n"
445
- }
446
- },
447
- "outputs": [],
448
- "source": [
449
- "targets = llama_tokens.input_ids.masked_fill(\n",
450
- " llama_tokens.input_ids == model.llama_tokenizer.pad_token_id, -100\n",
451
- " )"
452
- ]
453
- },
454
- {
455
- "cell_type": "code",
456
- "execution_count": 10,
457
- "id": "e458fa52",
458
- "metadata": {
459
- "pycharm": {
460
- "name": "#%%\n"
461
- }
462
- },
463
- "outputs": [
464
- {
465
- "data": {
466
- "text/plain": [
467
- "torch.Size([2, 3])"
468
- ]
469
- },
470
- "execution_count": 10,
471
- "metadata": {},
472
- "output_type": "execute_result"
473
- }
474
- ],
475
- "source": [
476
- "torch.ones([targets.shape[0], targets.shape[0]+1]).shape"
477
- ]
478
- },
479
- {
480
- "cell_type": "code",
481
- "execution_count": null,
482
- "id": "24607f7a",
483
- "metadata": {
484
- "pycharm": {
485
- "name": "#%%\n"
486
- }
487
- },
488
- "outputs": [],
489
- "source": [
490
- "text = \\\n",
491
- "\"### Human: What's your name?\" \\\n",
492
- "\"### Assistant: \"\n",
493
- "\n",
494
- "\n",
495
- "llama_tokens = model.llama_tokenizer(\n",
496
- " text, \n",
497
- " return_tensors=\"pt\", \n",
498
- " ).to(device)"
499
- ]
500
- },
501
- {
502
- "cell_type": "markdown",
503
- "id": "5e69d3e1",
504
- "metadata": {
505
- "pycharm": {
506
- "name": "#%% md\n"
507
- }
508
- },
509
- "source": [
510
- "### Example of Emb Input"
511
- ]
512
- },
513
- {
514
- "cell_type": "code",
515
- "execution_count": 188,
516
- "id": "205b092f",
517
- "metadata": {
518
- "pycharm": {
519
- "name": "#%%\n"
520
- }
521
- },
522
- "outputs": [
523
- {
524
- "name": "stdout",
525
- "output_type": "stream",
526
- "text": [
527
- "<unk>​\n",
528
- "\n",
529
- "I'm sorry, I am an AI language model and do not have a physical form or a name. My purpose is to assist you with any questions or tasks you may have to the best of my ability. Is there anything specific you would like help with?\n",
530
- "###\n"
531
- ]
532
- }
533
- ],
534
- "source": [
535
- "inputs_embeds = model.llama_model.model.embed_tokens(llama_tokens.input_ids)\n",
536
- "outputs = model.llama_model.generate(\n",
537
- " inputs_embeds=inputs_embeds,\n",
538
- " query_embeds=None,\n",
539
- " attention_mask=llama_tokens.attention_mask,\n",
540
- " max_new_tokens=500,\n",
541
- " stopping_criteria=stopping_criteria,\n",
542
- " )\n",
543
- "output_text = model.llama_tokenizer.decode(outputs[0])\n",
544
- "print(output_text)"
545
- ]
546
- },
547
- {
548
- "cell_type": "code",
549
- "execution_count": 189,
550
- "id": "561b42f5",
551
- "metadata": {
552
- "pycharm": {
553
- "name": "#%%\n"
554
- }
555
- },
556
- "outputs": [
557
- {
558
- "data": {
559
- "text/plain": [
560
- "torch.Size([1, 16, 5120])"
561
- ]
562
- },
563
- "execution_count": 189,
564
- "metadata": {},
565
- "output_type": "execute_result"
566
- }
567
- ],
568
- "source": [
569
- "inputs_embeds.shape"
570
- ]
571
- },
572
- {
573
- "cell_type": "markdown",
574
- "id": "a1694ad6",
575
- "metadata": {
576
- "pycharm": {
577
- "name": "#%% md\n"
578
- }
579
- },
580
- "source": [
581
- "### Example of ID Input"
582
- ]
583
- },
584
- {
585
- "cell_type": "code",
586
- "execution_count": null,
587
- "id": "c1dc7841",
588
- "metadata": {
589
- "pycharm": {
590
- "name": "#%%\n"
591
- }
592
- },
593
- "outputs": [],
594
- "source": [
595
- "outputs = model.llama_model.generate(\n",
596
- " input_ids=llama_tokens.input_ids,\n",
597
- " query_embeds=None,\n",
598
- " attention_mask=llama_tokens.attention_mask,\n",
599
- " max_new_tokens=500,\n",
600
- " stopping_criteria=stopping_criteria,\n",
601
- " )\n",
602
- "output_text = model.llama_tokenizer.decode(outputs[0])\n",
603
- "print(output_text)"
604
- ]
605
- },
606
- {
607
- "cell_type": "markdown",
608
- "id": "19dd1f9d",
609
- "metadata": {
610
- "pycharm": {
611
- "name": "#%% md\n"
612
- }
613
- },
614
- "source": []
615
- },
616
- {
617
- "cell_type": "markdown",
618
- "id": "468ac97e",
619
- "metadata": {
620
- "pycharm": {
621
- "name": "#%% md\n"
622
- }
623
- },
624
- "source": [
625
- "### Example of Mixed Input"
626
- ]
627
- },
628
- {
629
- "cell_type": "code",
630
- "execution_count": 47,
631
- "id": "4af3a9bf",
632
- "metadata": {
633
- "pycharm": {
634
- "name": "#%%\n"
635
- }
636
- },
637
- "outputs": [],
638
- "source": [
639
- "ckpt_path = '/home/zhud/project/blip2/lavis/output/BLIP2/Vicuna_pretrain_stage2_cc/20230408015/checkpoint_2.pth'\n",
640
- "ckpt = torch.load(ckpt_path, map_location=\"cpu\")\n",
641
- "msg = model.load_state_dict(ckpt['model'], strict=False)\n",
642
- "model = model.to(device)"
643
- ]
644
- },
645
- {
646
- "cell_type": "code",
647
- "execution_count": 48,
648
- "id": "c3148611",
649
- "metadata": {
650
- "pycharm": {
651
- "name": "#%%\n"
652
- }
653
- },
654
- "outputs": [],
655
- "source": [
656
- "# Load the image using PIL\n",
657
- "image = Image.open('test_img5.jpg').convert('RGB')\n",
658
- "image = vis_processor(image).unsqueeze(0).to(device)\n",
659
- "inputs_llama, atts_llama = model.encode_img(image)"
660
- ]
661
- },
662
- {
663
- "cell_type": "code",
664
- "execution_count": 53,
665
- "id": "07b82707",
666
- "metadata": {
667
- "pycharm": {
668
- "name": "#%%\n"
669
- }
670
- },
671
- "outputs": [],
672
- "source": [
673
- "text = \\\n",
674
- "\"A chat between a curious human and an artificial intelligence assistant. \" \\\n",
675
- "\"The assistant gives helpful, detailed, and polite answers to the human's questions. \"\\\n",
676
- "\"Human may ask questions related to a given image. \" \\\n",
677
- "\"The image will be wrapped as <Img> IMAGE_CONTENT </Img> \" \\\n",
678
- "\"### Human: <Img>To_Split</Img> \" \\\n",
679
- "\"### Assistant: Received the image. \" \\\n",
680
- "\"### Human: Describe the image in detail. Say everthing you see. Describe all the things.\" \\\n",
681
- "\"### Assistant: \"\n",
682
- "\n",
683
- "\n",
684
- "text = \\\n",
685
- "\"A chat between a curious human and an artificial intelligence assistant. \" \\\n",
686
- "\"The assistant gives helpful, detailed, and polite answers to the human's questions. \"\\\n",
687
- "\"Human may ask questions related to a given image. \" \\\n",
688
- "\"The image will be wrapped as <Img> IMAGE_CONTENT </Img> \" \\\n",
689
- "\"### Human: Describe the image in detail. Say everthing you see. <Img>To_Split</Img> \" \\\n",
690
- "\"### Assistant: \"\n",
691
- "\n",
692
- "text = \\\n",
693
- "\"### Human: Describe the image in detail. Say everthing you see. <Img>To_Split</Img> \" \\\n",
694
- "\"### Assistant: \"\n",
695
- "\n",
696
- "\n",
697
- "\n",
698
- "# text = \\\n",
699
- "# \"A chat between a curious human and an artificial intelligence assistant. \" \\\n",
700
- "# \"The assistant gives helpful, detailed, and polite answers to the human's questions. \"\\\n",
701
- "# \"Human may ask questions related to a given image. \" \\\n",
702
- "# \"The image will be wrapped as <Img> IMAGE_CONTENT </Img> \" \\\n",
703
- "# \"### Human: <Img>To_Split</Img> \" \\\n",
704
- "# \"### Assistant: Received the image. \" \\\n",
705
- "# \"### Human: This is a draft of a website. Give me the html code to write this website. \" \\\n",
706
- "# \"Btw, you need to come up with some jokes in the website to fill the placeholders. \" \\\n",
707
- "# \"Also, make the website colorful and vivid. \" \\\n",
708
- "# \"### Assistant: \"\n",
709
- "\n",
710
- "\n",
711
- "# text = \\\n",
712
- "# \"Return what the human says. \" \\\n",
713
- "# \"### Human: There is a big elephant in the sky. \" \\\n",
714
- "# \"### Assistant: There is a big elephant in the sky. \" \\\n",
715
- "# \"### Human: fdjlks klcznv_l1 \" \\\n",
716
- "# \"### Assistant: fdjlks klcznv_l1 \" \\\n",
717
- "# \"### Human: To_Split \" \\\n",
718
- "# \"### Assistant: \"\n",
719
- "\n",
720
- "\n",
721
- "text_1, text_2 = text.split('To_Split')\n",
722
- "\n",
723
- "text_1_tokens = model.llama_tokenizer(text_1, return_tensors=\"pt\").to(device)\n",
724
- "text_2_tokens = model.llama_tokenizer(text_2, return_tensors=\"pt\", add_special_tokens=False).to(device)\n",
725
- "text_1_emb = model.llama_model.model.embed_tokens(text_1_tokens.input_ids)\n",
726
- "text_2_emb = model.llama_model.model.embed_tokens(text_2_tokens.input_ids)"
727
- ]
728
- },
729
- {
730
- "cell_type": "code",
731
- "execution_count": 54,
732
- "id": "136b9e97",
733
- "metadata": {
734
- "pycharm": {
735
- "name": "#%%\n"
736
- }
737
- },
738
- "outputs": [
739
- {
740
- "name": "stdout",
741
- "output_type": "stream",
742
- "text": [
743
- "<unk>\n",
744
- "\n",
745
- "The image shows a small bird perched on a tree stump, with a camera lens in the background\n",
746
- "\n",
747
- "The bird is a small bird, with a bright yellow beak and black feathers. It is perched on a tree stump, with its wings spread out and its beak open. The bird is looking to the left, as if it is about to take off.\n",
748
- "\n",
749
- "The camera lens in the background is a large, black lens with a silver ring around the front. The lens is attached to a camera, which is not visible in the image. The lens is pointed at the bird, with the camera's viewfinder showing the bird in the center of the frame.\n",
750
- "\n",
751
- "The background of the image is a forest, with trees and foliage visible in the distance. The trees are covered in leaves, and there is a thick layer of mist or fog in the air, which gives the image a dreamy, ethereal quality.\n",
752
- "\n",
753
- "The lighting in the image is soft and diffused, with the sun shining through the trees and casting a warm, golden light on the bird and the tree stump. The lighting creates deep shadows in the forest, which add to the sense of mystery and wonder in the image.\n",
754
- "\n",
755
- "The overall effect of the image is one of peacefulness and tranquility, with the bird and the forest creating a sense of calm and serenity. The image is beautifully composed, with the bird and the camera lens creating a visual balance that draws the viewer's eye to the center of the frame.\n",
756
- "###\n"
757
- ]
758
- }
759
- ],
760
- "source": [
761
- "outputs = model.llama_model.generate(\n",
762
- " inputs_embeds=torch.concat([text_1_emb, inputs_llama, text_2_emb], dim=1),\n",
763
- " query_embeds=None,\n",
764
- " attention_mask=torch.concat([text_1_tokens.attention_mask, atts_llama, text_2_tokens.attention_mask], dim=1),\n",
765
- " max_new_tokens=600,\n",
766
- " stopping_criteria=stopping_criteria,\n",
767
- " )\n",
768
- "output_text = model.llama_tokenizer.decode(outputs[0])\n",
769
- "print(output_text)"
770
- ]
771
- },
772
- {
773
- "cell_type": "code",
774
- "execution_count": 83,
775
- "id": "54cc3d4a",
776
- "metadata": {
777
- "pycharm": {
778
- "name": "#%%\n"
779
- }
780
- },
781
- "outputs": [],
782
- "source": [
783
- "with open('lavis/prompts/image_caption.txt', 'r') as f:\n",
784
- " prompts = f.read().splitlines()"
785
- ]
786
- },
787
- {
788
- "cell_type": "code",
789
- "execution_count": 92,
790
- "id": "f52cd85c",
791
- "metadata": {
792
- "pycharm": {
793
- "name": "#%%\n"
794
- }
795
- },
796
- "outputs": [],
797
- "source": [
798
- "prompt_token = model.llama_tokenizer(prompts, return_tensors=\"pt\", padding=\"longest\",)"
799
- ]
800
- },
801
- {
802
- "cell_type": "code",
803
- "execution_count": 103,
804
- "id": "4b0cf1d0",
805
- "metadata": {
806
- "pycharm": {
807
- "name": "#%%\n"
808
- }
809
- },
810
- "outputs": [
811
- {
812
- "name": "stdout",
813
- "output_type": "stream",
814
- "text": [
815
- "[(15, 6), (16, 11), (17, 17), (18, 17), (19, 27), (20, 18), (21, 21), (22, 4), (23, 6), (24, 2)]\n"
816
- ]
817
- }
818
- ],
819
- "source": [
820
- "\n",
821
- "\n",
822
- "my_list = prompt_token.attention_mask.sum(1).numpy()\n",
823
- "counts = {}\n",
824
- "\n",
825
- "for element in my_list:\n",
826
- " if element in counts:\n",
827
- " counts[element] += 1\n",
828
- " else:\n",
829
- " counts[element] = 1\n",
830
- "\n",
831
- "print(sorted(counts.items(), key=lambda item: item[0]))"
832
- ]
833
- },
834
- {
835
- "cell_type": "code",
836
- "execution_count": 58,
837
- "id": "f7919e93",
838
- "metadata": {
839
- "pycharm": {
840
- "name": "#%%\n"
841
- }
842
- },
843
- "outputs": [
844
- {
845
- "name": "stdout",
846
- "output_type": "stream",
847
- "text": [
848
- "[1, 2, 1, 2, 1, 2]\n"
849
- ]
850
- }
851
- ],
852
- "source": [
853
- "a,b = [1,1,1], [2,2,2]\n",
854
- "c = [i for pair in zip(a,b) for i in pair]\n",
855
- "print(c)"
856
- ]
857
- },
858
- {
859
- "cell_type": "markdown",
860
- "id": "3c64a037",
861
- "metadata": {
862
- "pycharm": {
863
- "name": "#%% md\n"
864
- }
865
- },
866
- "source": [
867
- "### Example of Image Input"
868
- ]
869
- },
870
- {
871
- "cell_type": "code",
872
- "execution_count": 67,
873
- "id": "87164578",
874
- "metadata": {
875
- "pycharm": {
876
- "name": "#%%\n"
877
- }
878
- },
879
- "outputs": [
880
- {
881
- "name": "stdout",
882
- "output_type": "stream",
883
- "text": [
884
- "<unk>a bird eating from a bird feeder\n",
885
- "\n",
886
- "bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird\n",
887
- "bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird\n",
888
- "bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird\n",
889
- "bird feeder, bird feeder, bird feeder\n"
890
- ]
891
- }
892
- ],
893
- "source": [
894
- "inputs_embeds = model.llama_model.model.embed_tokens(llama_tokens.input_ids)\n",
895
- "bos_embeds = model.llama_model.model.embed_tokens(torch.tensor(model.llama_tokenizer.bos_token_id, device=device))[None, None]\n",
896
- "outputs = model.llama_model.generate(\n",
897
- " inputs_embeds=torch.concat([bos_embeds, inputs_llama], dim=1),\n",
898
- " query_embeds=None,\n",
899
- " attention_mask=torch.concat([atts_llama[:, :1], atts_llama], dim=1),\n",
900
- " max_new_tokens=100,\n",
901
- " stopping_criteria=stopping_criteria,\n",
902
- " )\n",
903
- "output_text = model.llama_tokenizer.decode(outputs[0])\n",
904
- "print(output_text)"
905
- ]
906
- }
907
- ],
908
- "metadata": {
909
- "kernelspec": {
910
- "display_name": "eye",
911
- "language": "python",
912
- "name": "eye"
913
- },
914
- "language_info": {
915
- "codemirror_mode": {
916
- "name": "ipython",
917
- "version": 3
918
- },
919
- "file_extension": ".py",
920
- "mimetype": "text/x-python",
921
- "name": "python",
922
- "nbconvert_exporter": "python",
923
- "pygments_lexer": "ipython3",
924
- "version": "3.9.16"
925
- }
926
- },
927
- "nbformat": 4,
928
- "nbformat_minor": 5
929
- }
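
For context, the technique this notebook was prototyping is mixed image-text input: project an image into the LLaMA embedding space, split the text prompt at a placeholder, and splice the image embeddings between the two embedded text halves before generation. Below is a minimal sketch of that pattern, assembled from the deleted cells. It assumes model and vis_processor are already built and loaded as in the notebook; encode_img, llama_tokenizer, llama_model, and the query_embeds keyword follow this repository's modified BLIP2-LLaMA code (query_embeds is not part of stock transformers), so treat the names as illustrative rather than a stable API.

import torch
from PIL import Image
from transformers import StoppingCriteria, StoppingCriteriaList

class StoppingCriteriaSub(StoppingCriteria):
    """Stop once any of the given token-id suffixes appears at the end of the output."""
    def __init__(self, stops=None):
        super().__init__()
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return any(torch.all(stop == input_ids[0][-len(stop):]).item() for stop in self.stops)

device = 'cuda:0'
# '###' can be tokenized in different ways, so both id sequences are listed.
stop_words_ids = [torch.tensor([835]).to(device),
                  torch.tensor([2277, 29937]).to(device)]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

# 1. Encode the image into LLaMA-space embeddings (Q-Former + linear projection).
image = Image.open('test_img5.jpg').convert('RGB')
image = vis_processor(image).unsqueeze(0).to(device)   # vis_processor assumed preloaded
inputs_llama, atts_llama = model.encode_img(image)     # model assumed preloaded

# 2. Split the prompt at the image placeholder and embed each half separately.
prompt = "### Human: Describe the image in detail. Say everything you see. " \
         "<Img>To_Split</Img> ### Assistant: "
text_1, text_2 = prompt.split('To_Split')
text_1_tokens = model.llama_tokenizer(text_1, return_tensors="pt").to(device)
text_2_tokens = model.llama_tokenizer(text_2, return_tensors="pt",
                                      add_special_tokens=False).to(device)
text_1_emb = model.llama_model.model.embed_tokens(text_1_tokens.input_ids)
text_2_emb = model.llama_model.model.embed_tokens(text_2_tokens.input_ids)

# 3. Generate from the spliced sequence [text_1, image, text_2].
outputs = model.llama_model.generate(
    inputs_embeds=torch.cat([text_1_emb, inputs_llama, text_2_emb], dim=1),
    query_embeds=None,  # repo-specific argument of the modified llama_model
    attention_mask=torch.cat([text_1_tokens.attention_mask, atts_llama,
                              text_2_tokens.attention_mask], dim=1),
    max_new_tokens=300,
    stopping_criteria=stopping_criteria,
)
print(model.llama_tokenizer.decode(outputs[0]))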