Upload 5 files
Browse files- .gitattributes +1 -0
- README.md +7 -6
- config.json +4 -1
- demo.ipynb +217 -0
- example_image/aki_compressed.jpg +3 -0
- model.safetensors +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
example_image/aki_compressed.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -7,13 +7,15 @@ pipeline_tag: image-text-to-text
|
|
7 |
|
8 |
# AKI Model Card
|
9 |
`AKI` is the official checkpoint for the paper "[Seeing is Understanding: Unlocking Causal Attention into Modality-Mutual Attention for Multimodal LLMs](https://arxiv.org/abs/2503.02597)".
|
10 |
-
AKI is a multimodal foundation model that unlocks causal attention in the LLM into modality-mutual attention (MMA), which enables the earlier modality (images) to incorporate information from the latter modality (text) without introducing additional parameters and increasing training time.
|
11 |
|
12 |
## Model Details
|
13 |
### Model Descriptions
|
14 |
- Vision Encoder: [google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384)
|
15 |
-
- Vision-Language Connector: Perceiver Resampler
|
16 |
- Language Decoder (LLM): [microsoft/Phi-3.5-mini-instruct](https://huggingface.co/microsoft/Phi-3.5-mini-instruct)
|
|
|
|
|
17 |
|
18 |
### Model Sources
|
19 |
- Repository: [GitHub](https://github.com/sony/aki)
|
@@ -35,11 +37,10 @@ Describe the scene of this image.
|
|
35 |
> : The image captures a beautiful autumn day in a park, with a pathway covered in a vibrant carpet of fallen leaves. The leaves are in various shades of red, orange, yellow, and brown, creating a warm and colorful atmosphere. The path is lined with trees displaying beautiful autumn foliage, adding to the picturesque setting. ...
|
36 |
|
37 |
### Inference Example
|
38 |
-
|
|
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
```
|
43 |
|
44 |
## Evaluation Results
|
45 |
### Main Comparisons with the Same Configurations (Table 1)
|
|
|
7 |
|
8 |
# AKI Model Card
|
9 |
`AKI` is the official checkpoint for the paper "[Seeing is Understanding: Unlocking Causal Attention into Modality-Mutual Attention for Multimodal LLMs](https://arxiv.org/abs/2503.02597)".
|
10 |
+
AKI is a multimodal foundation model that unlocks causal attention in the LLM into modality-mutual attention (MMA), which enables the earlier modality (images) to incorporate information from the latter modality (text) for addressing vision-language misalignment without introducing additional parameters and increasing training time.
|
11 |
|
12 |
## Model Details
|
13 |
### Model Descriptions
|
14 |
- Vision Encoder: [google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384)
|
15 |
+
- Vision-Language Connector: [Perceiver Resampler](https://arxiv.org/abs/2204.14198)
|
16 |
- Language Decoder (LLM): [microsoft/Phi-3.5-mini-instruct](https://huggingface.co/microsoft/Phi-3.5-mini-instruct)
|
17 |
+
- Pretraining Datasets: [Blip3-kale](https://huggingface.co/datasets/Salesforce/blip3-kale) and [Blip3-OCR-200m](https://huggingface.co/datasets/Salesforce/blip3-ocr-200m)
|
18 |
+
- SFT Datasets: VQAv2, GQA, VSR, OCRVQA, A-OKVQA, ScienceQA, RefCOCO, RefCOCOg, RefCOCO+, VisualGnome, LLaVA-150k
|
19 |
|
20 |
### Model Sources
|
21 |
- Repository: [GitHub](https://github.com/sony/aki)
|
|
|
37 |
> : The image captures a beautiful autumn day in a park, with a pathway covered in a vibrant carpet of fallen leaves. The leaves are in various shades of red, orange, yellow, and brown, creating a warm and colorful atmosphere. The path is lined with trees displaying beautiful autumn foliage, adding to the picturesque setting. ...
|
38 |
|
39 |
### Inference Example
|
40 |
+
Please refer to the [notebook](demo.ipynb) for the zero-shot inference.
|
41 |
+
To build a local demo website, please refer to [local_demo.py](https://github.com/sony/aki/blob/main/codes/open_flamingo/local_demo.py).
|
42 |
|
43 |
+
> For the training scripts, please refer to the [GitHub repo](https://github.com/sony/aki).
|
|
|
|
|
44 |
|
45 |
## Evaluation Results
|
46 |
### Main Comparisons with the Same Configurations (Table 1)
|
config.json
CHANGED
@@ -7,5 +7,8 @@
|
|
7 |
"num_vision_tokens": 144,
|
8 |
"pad_token_id": 32011,
|
9 |
"tokenizer": null,
|
10 |
-
"vision_encoder_path": "google/siglip-so400m-patch14-384"
|
|
|
|
|
|
|
11 |
}
|
|
|
7 |
"num_vision_tokens": 144,
|
8 |
"pad_token_id": 32011,
|
9 |
"tokenizer": null,
|
10 |
+
"vision_encoder_path": "google/siglip-so400m-patch14-384",
|
11 |
+
"n_px": 384,
|
12 |
+
"norm_mean": 0.5,
|
13 |
+
"norm_std": 0.5
|
14 |
}
|
demo.ipynb
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stderr",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"/home/Weiyao.Wang/virtualenvs/Kanzo/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
13 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"from src.aki import AKI\n",
|
19 |
+
"from transformers import AutoTokenizer, AutoConfig\n",
|
20 |
+
"from torchvision.transforms import Compose, Resize, Lambda, ToTensor, Normalize\n",
|
21 |
+
"from PIL import Image\n",
|
22 |
+
"try:\n",
|
23 |
+
" from torchvision.transforms import InterpolationMode\n",
|
24 |
+
" BICUBIC = InterpolationMode.BICUBIC\n",
|
25 |
+
"except ImportError:\n",
|
26 |
+
" BICUBIC = Image.BICUBIC"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": 2,
|
32 |
+
"metadata": {},
|
33 |
+
"outputs": [],
|
34 |
+
"source": [
|
35 |
+
"def apply_prompt_template(query: str) -> str:\n",
|
36 |
+
" SYSTEM_BASE = \"A chat between a curious user and an artificial intelligence assistant.\"\n",
|
37 |
+
" SYSTEM_DETAIL = \"The assistant gives helpful, detailed, and polite answers to the user's questions.\"\n",
|
38 |
+
" SYSTEM_MESSAGE = SYSTEM_BASE + \" \" + SYSTEM_DETAIL\n",
|
39 |
+
" SYSTEM_MESSAGE_ROLE = '<|system|>' + '\\n' + SYSTEM_MESSAGE + '<|end|>\\n'\n",
|
40 |
+
"\n",
|
41 |
+
" s = (\n",
|
42 |
+
" f'<s> {SYSTEM_MESSAGE_ROLE}'\n",
|
43 |
+
" f'<|user|>\\n<image>\\n{query}<|end|>\\n<|assistant|>\\n'\n",
|
44 |
+
" )\n",
|
45 |
+
" return s\n",
|
46 |
+
"\n",
|
47 |
+
"\n",
|
48 |
+
"def load_model_and_processor(ckpt_path, config):\n",
|
49 |
+
" n_px = getattr(config, \"n_px\", 384)\n",
|
50 |
+
" norm_mean = getattr(config, \"norm_mean\", 0.5)\n",
|
51 |
+
" norm_std = getattr(config, \"norm_std\", 0.5)\n",
|
52 |
+
"\n",
|
53 |
+
" # replace GenerationMixin to modify attention mask handling\n",
|
54 |
+
" from transformers.generation.utils import GenerationMixin\n",
|
55 |
+
" from open_flamingo import _aki_update_model_kwargs_for_generation\n",
|
56 |
+
" GenerationMixin._update_model_kwargs_for_generation = _aki_update_model_kwargs_for_generation\n",
|
57 |
+
" \n",
|
58 |
+
" tokenizer = AutoTokenizer.from_pretrained(ckpt_path)\n",
|
59 |
+
" model = AKI.from_pretrained(ckpt_path, tokenizer=tokenizer)\n",
|
60 |
+
" image_processor = Compose([\n",
|
61 |
+
" Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC, antialias=True),\n",
|
62 |
+
" Lambda(lambda x: x.convert('RGB')),\n",
|
63 |
+
" ToTensor(),\n",
|
64 |
+
" Normalize(mean=(norm_mean, norm_mean, norm_mean), std=(norm_std, norm_std, norm_std))\n",
|
65 |
+
" ])\n",
|
66 |
+
"\n",
|
67 |
+
" model.eval().cuda()\n",
|
68 |
+
" print(\"Model initialization is done.\")\n",
|
69 |
+
" return model, image_processor, tokenizer"
|
70 |
+
]
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"cell_type": "code",
|
74 |
+
"execution_count": 3,
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [
|
77 |
+
{
|
78 |
+
"name": "stderr",
|
79 |
+
"output_type": "stream",
|
80 |
+
"text": [
|
81 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
82 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
83 |
+
"`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.\n",
|
84 |
+
"Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.\n",
|
85 |
+
"Loading checkpoint shards: 100%|ββββββββββ| 2/2 [00:03<00:00, 1.52s/it]\n"
|
86 |
+
]
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"name": "stdout",
|
90 |
+
"output_type": "stream",
|
91 |
+
"text": [
|
92 |
+
"Loading weights from local directory\n",
|
93 |
+
"Model initialization is done.\n"
|
94 |
+
]
|
95 |
+
}
|
96 |
+
],
|
97 |
+
"source": [
|
98 |
+
"model_path = \"/home/Weiyao.Wang/projects/Multimodal-Foundation-Models/codes/open_flamingo/aki-phi3.5-mini-4b\"\n",
|
99 |
+
"config = AutoConfig.from_pretrained(model_path)\n",
|
100 |
+
"# Load model, image_processor, tokenizer\n",
|
101 |
+
"model, image_processor, tokenizer = load_model_and_processor(model_path, config=config)"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"execution_count": 7,
|
107 |
+
"metadata": {},
|
108 |
+
"outputs": [],
|
109 |
+
"source": [
|
110 |
+
"def process_input(image_path: str, text_input: str) -> str:\n",
|
111 |
+
" \"\"\"\n",
|
112 |
+
" Processes the input image and text prompt to generate a response from the AKI model.\n",
|
113 |
+
" \n",
|
114 |
+
" Args:\n",
|
115 |
+
" image_path (str): The path of the image.\n",
|
116 |
+
" text_input (str): The text prompt to accompany the image.\n",
|
117 |
+
" \n",
|
118 |
+
" Returns:\n",
|
119 |
+
" str: The generated text from the model.\n",
|
120 |
+
" \"\"\"\n",
|
121 |
+
"\n",
|
122 |
+
" image = Image.open(image_path).convert('RGB')\n",
|
123 |
+
" \n",
|
124 |
+
" # tokenize text input with the chat template\n",
|
125 |
+
" prompt = apply_prompt_template(text_input)\n",
|
126 |
+
" lang_x = tokenizer([prompt], return_tensors='pt', add_special_tokens=False)\n",
|
127 |
+
"\n",
|
128 |
+
" print(\"Prompt:\", prompt)\n",
|
129 |
+
" \n",
|
130 |
+
" # Preprocess inputs for the model\n",
|
131 |
+
" vision_x = image_processor(image)[None, None, None, ...].cuda()\n",
|
132 |
+
"\n",
|
133 |
+
" generation_kwargs = {\n",
|
134 |
+
" 'max_new_tokens': 256,\n",
|
135 |
+
" 'do_sample': False,\n",
|
136 |
+
" }\n",
|
137 |
+
" \n",
|
138 |
+
" # Generate the model's output based on the inputs\n",
|
139 |
+
" output = model.generate(\n",
|
140 |
+
" vision_x=vision_x.cuda(),\n",
|
141 |
+
" lang_x=lang_x['input_ids'].cuda(),\n",
|
142 |
+
" attention_mask=lang_x['attention_mask'].cuda(),\n",
|
143 |
+
" **generation_kwargs\n",
|
144 |
+
" )\n",
|
145 |
+
" \n",
|
146 |
+
" # Decode the generated output into readable text\n",
|
147 |
+
" generated_text = tokenizer.decode(output[0], skip_special_tokens=True)\n",
|
148 |
+
" \n",
|
149 |
+
" return generated_text"
|
150 |
+
]
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"cell_type": "code",
|
154 |
+
"execution_count": 10,
|
155 |
+
"metadata": {},
|
156 |
+
"outputs": [
|
157 |
+
{
|
158 |
+
"name": "stdout",
|
159 |
+
"output_type": "stream",
|
160 |
+
"text": [
|
161 |
+
"Prompt: <s> <|system|>\n",
|
162 |
+
"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n",
|
163 |
+
"<|user|>\n",
|
164 |
+
"<image>\n",
|
165 |
+
"Describe the scene of this image.<|end|>\n",
|
166 |
+
"<|assistant|>\n",
|
167 |
+
"\n"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"name": "stdout",
|
172 |
+
"output_type": "stream",
|
173 |
+
"text": [
|
174 |
+
"Response:\n",
|
175 |
+
" The image captures a beautiful autumn day in a park, with a pathway covered in a vibrant carpet of fallen leaves. The leaves are in various shades of red, orange, yellow, and brown, creating a warm and colorful atmosphere. The path is lined with trees displaying beautiful autumn foliage, adding to the picturesque setting.\n",
|
176 |
+
"\n",
|
177 |
+
"A few benches are scattered along the path, providing visitors with a place to sit and enjoy the view of the falling leaves and the surrounding trees. The overall scene is serene and inviting, making it an ideal spot for relaxation and appreciating the beauty of the season.\n"
|
178 |
+
]
|
179 |
+
}
|
180 |
+
],
|
181 |
+
"source": [
|
182 |
+
"image_path = \"example_image/aki_compressed.jpg\"\n",
|
183 |
+
"text_input = \"Describe the scene of this image.\"\n",
|
184 |
+
"response = process_input(image_path, text_input)\n",
|
185 |
+
"print(\"Response:\\n\", response)"
|
186 |
+
]
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"cell_type": "code",
|
190 |
+
"execution_count": null,
|
191 |
+
"metadata": {},
|
192 |
+
"outputs": [],
|
193 |
+
"source": []
|
194 |
+
}
|
195 |
+
],
|
196 |
+
"metadata": {
|
197 |
+
"kernelspec": {
|
198 |
+
"display_name": "Kanzo",
|
199 |
+
"language": "python",
|
200 |
+
"name": "python3"
|
201 |
+
},
|
202 |
+
"language_info": {
|
203 |
+
"codemirror_mode": {
|
204 |
+
"name": "ipython",
|
205 |
+
"version": 3
|
206 |
+
},
|
207 |
+
"file_extension": ".py",
|
208 |
+
"mimetype": "text/x-python",
|
209 |
+
"name": "python",
|
210 |
+
"nbconvert_exporter": "python",
|
211 |
+
"pygments_lexer": "ipython3",
|
212 |
+
"version": "3.12.6"
|
213 |
+
}
|
214 |
+
},
|
215 |
+
"nbformat": 4,
|
216 |
+
"nbformat_minor": 2
|
217 |
+
}
|
example_image/aki_compressed.jpg
ADDED
![]() |
Git LFS Details
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1154b183974b8ab07bd8e5f36a093cb50a37751825caacd648b0acd92e5cfc4a
|
3 |
+
size 17323922632
|