Sony /

Image-Text-to-Text · Safetensors · English · conversational
SwyWang committed on
Commit 7eb0198 · verified · 1 Parent(s): ed1788a

Upload 8 files

Files changed (8)
  1. README.md +2 -1
  2. demo.ipynb +0 -0
  3. src/__init__.py +0 -0
  4. src/aki.py +226 -0
  5. src/aki_generation.py +86 -0
  6. src/helpers.py +613 -0
  7. src/utils.py +108 -0
  8. src/vlm.py +777 -0
README.md CHANGED
@@ -3,6 +3,7 @@ license: cc-by-nc-4.0
  language:
  - en
  pipeline_tag: image-text-to-text
+ arxiv: 2503.02597
  ---

  # AKI Model Card
@@ -34,7 +35,7 @@ Describe the scene of this image.
  <|end|>
  <|assistant|>
  ```
- > : The image captures a beautiful autumn day in a park, with a pathway covered in a vibrant carpet of fallen leaves. The leaves are in various shades of red, orange, yellow, and brown, creating a warm and colorful atmosphere. The path is lined with trees displaying beautiful autumn foliage, adding to the picturesque setting. ...
+ > The image captures a beautiful autumn day in a park, with a pathway covered in a vibrant carpet of fallen leaves. The leaves are in various shades of red, orange, yellow, and brown, creating a warm and colorful atmosphere. The path is lined with trees displaying beautiful autumn foliage, adding to the picturesque setting. ...

  ### Inference Example
  Please refer to the [notebook](demo.ipynb) for the zero-shot inference.
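For orientation only, here is a minimal, hypothetical sketch of what the zero-shot call could look like, assuming the checkpoint loads via `PyTorchModelHubMixin.from_pretrained` and reusing the prompt layout from the example above. The repo id, tokenizer path, and image preprocessing are placeholders; demo.ipynb remains the authoritative reference.

```python
# Hypothetical sketch only; see demo.ipynb for the actual preprocessing and prompt template.
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer

from src.aki import AKI

model = AKI.from_pretrained("<this-repo-id>")              # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained("<lang-model>")  # placeholder tokenizer path

# Prompt layout mirrors the README example: media token, question, chat markers.
prompt = "<image>\nDescribe the scene of this image.\n<|end|>\n<|assistant|>"
lang_x = tokenizer(prompt, return_tensors="pt").input_ids

# Placeholder transform; the real resize/normalization should match the notebook
# and the vision encoder's expected image size.
preprocess = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor()])

# forward()/generate() expect vision_x shaped (B, T_img, F, C, H, W) with F = 1.
image = Image.open("example.jpg").convert("RGB")
vision_x = preprocess(image).unsqueeze(0).unsqueeze(0).unsqueeze(0)  # (1, 1, 1, C, H, W)

with torch.no_grad():
    output = model.generate(vision_x=vision_x, lang_x=lang_x, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```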
demo.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
src/__init__.py ADDED
File without changes
src/aki.py ADDED
@@ -0,0 +1,226 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import nn
4
+ from typing import List, Optional, Tuple, Union
5
+ from huggingface_hub import PyTorchModelHubMixin
6
+ from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
7
+
8
+ from .helpers import PerceiverResampler
9
+ from .vlm import VLMWithLanguageStream
10
+
11
+ class AKI(VLMWithLanguageStream, PyTorchModelHubMixin):
12
+ def __init__(
13
+ self,
14
+ vision_encoder_path: str,
15
+ lang_model_path: str,
16
+ pad_token_id: int,
17
+ initial_tokenizer_len: Optional[int] = None,
18
+ tokenizer: Optional[AutoTokenizer] = None,
19
+ decoder_layers_attr_name: str = None,
20
+ gradient_checkpointing: bool = False,
21
+ base_img_size: Optional[int] = None,
22
+ num_vision_tokens: int = 144,
23
+ ):
24
+ """
25
+ Args:
26
+ vision_encoder_path (str): path to the HF vision model (e.g. CLIP/SigLIP); its vision tower is used
27
+ lang_model_path (str): path to the HF causal language model
28
+ vis_feature_dim (int): final dimension of the visual features outputted by the vision_encoder
29
+ initial_tokenizer_len (int): size of the tokenizer vocab
30
+ pad_token_id (int): id of the padding token. None if no padding token; then a padding token
31
+ will be inserted into self.special_tokens, which factory.py fills after creating new tokens
32
+ decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
33
+ gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
34
+ """
35
+
36
+ # load the vision model
37
+ model = AutoModel.from_pretrained(vision_encoder_path)
38
+ vision_encoder = model.vision_model
39
+ vis_feature_dim = vision_encoder.config.hidden_size
40
+
41
+ # load the language model
42
+ lang_model = AutoModelForCausalLM.from_pretrained(
43
+ lang_model_path,
44
+ local_files_only=False,
45
+ trust_remote_code=True,
46
+ )
47
+
48
+ self._special_tokens = {
49
+ "media_token": "<image>",
50
+ "end_of_trunk_token": "<|endofchunk|>",
51
+ }
52
+ lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
53
+ super().__init__(
54
+ vision_encoder=vision_encoder,
55
+ vision_tokenizer=PerceiverResampler(
56
+ dim=vis_feature_dim, dim_inner=lang_embedding_dim,
57
+ num_latents=num_vision_tokens,
58
+ ),
59
+ lang_model=lang_model,
60
+ initial_tokenizer_len=initial_tokenizer_len,
61
+ gradient_checkpointing=gradient_checkpointing,
62
+ base_img_size=base_img_size,
63
+ decoder_layers_attr_name=decoder_layers_attr_name,
64
+ pad_token_id=pad_token_id,
65
+ )
66
+
67
+ if tokenizer is not None:
68
+ self.lang_model.config.vocab_size = len(tokenizer)
69
+ self.set_special_token_ids(
70
+ {
71
+ v: tokenizer.convert_tokens_to_ids(v)
72
+ for v in self.special_tokens.values()
73
+ }
74
+ )
75
+
76
+ def set_trainable(self):
77
+ """
78
+ Unfreeze everything except the vision_encoder
79
+ """
80
+ self.requires_grad_(True)
81
+ self.vision_encoder.requires_grad_(False)
82
+
83
+ def forward(
84
+ self,
85
+ vision_x: Optional[torch.Tensor],
86
+ lang_x: torch.Tensor,
87
+ attention_mask: Optional[torch.Tensor] = None,
88
+ labels: Optional[torch.Tensor] = None,
89
+ past_key_values: Optional[
90
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
91
+ ] = None,
92
+ past_media_locations: Optional[torch.Tensor] = None,
93
+ past_vision_tokens: Optional[torch.Tensor] = None,
94
+ use_cache: Optional[bool] = False,
95
+ **kwargs,
96
+ ):
97
+ """
98
+ Args:
99
+ vision_x: Vision input
100
+ shape (B, T_img, F, C, H, W) with F=1
101
+ only F = 1 is supported (single-frame videos)
102
+ if T_img > the number of media tokens in the corresponding input_ids (lang_x),
103
+ only the first number of media tokens in lang_x are used
104
+ lang_x: Language input ids, with media tokens denoting where
105
+ visual media should be inserted.
106
+ shape (B, T_txt)
107
+ attention_mask: Attention mask. Defaults to None.
108
+ labels: Labels. Defaults to None.
109
+ shape (B, T_txt)
110
+ past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
111
+ list of length = number of decoder layers in the LM
112
+ exact implementation depends on LM, see Hugging Face docs
113
+ past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
114
+ shape (B, T_txt)
115
+ past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
116
+ use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
117
+ If True, includes key_values, media_locations, and vision_tokens in the output.
118
+ """
119
+ assert not (past_vision_tokens is None) ^ (
120
+ past_media_locations is None
121
+ ), "past_vision_tokens and past_media_locations must both be None or both be not None"
122
+
123
+ # convert pixels to vision tokens
124
+ vision_attention_mask = None
125
+ if vision_x is not None:
126
+ vision_tokens = self.vision_tokenizer(self._encode_vision_x(vision_x=vision_x))
127
+ else:
128
+ vision_tokens = None
129
+
130
+ # fuse the vision and language tokens
131
+ new_inputs = self._prepare_inputs_for_forward(
132
+ vision_tokens=vision_tokens,
133
+ lang_x=lang_x,
134
+ attention_mask=attention_mask,
135
+ vision_attention_mask=vision_attention_mask,
136
+ labels=labels,
137
+ past_key_values=past_key_values,
138
+ past_media_locations=past_media_locations,
139
+ padding_side="right",
140
+ past_vision_tokens=past_vision_tokens,
141
+ )
142
+ output = self.lang_model(
143
+ **new_inputs,
144
+ use_cache=use_cache,
145
+ past_key_values=past_key_values,
146
+ **kwargs,
147
+ )
148
+
149
+ # postforward hooks
150
+ self._post_forward_hook()
151
+ return output
152
+
153
+ def generate(
154
+ self,
155
+ vision_x: torch.Tensor,
156
+ lang_x: torch.Tensor,
157
+ attention_mask: torch.Tensor = None,
158
+ past_key_values: Optional[
159
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
160
+ ] = None,
161
+ past_media_locations: Optional[torch.Tensor] = None,
162
+ past_vision_tokens: Optional[torch.Tensor] = None,
163
+ **kwargs,
164
+ ):
165
+ """
166
+ Generate text conditioned on vision and language inputs.
167
+ Args:
168
+ vision_x (torch.Tensor): Vision input
169
+ shape (B, T_img, F, C, H, W)
170
+ see documentation for forward
171
+ lang_x (torch.Tensor): Language input
172
+ shape (B, T_txt)
173
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
174
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
175
+ Returns:
176
+ torch.Tensor: lang_x with generated tokens appended to it
177
+ """
178
+ num_beams = kwargs.pop("num_beams", 1)
179
+
180
+ # convert pixels to vision tokens
181
+ vision_attention_mask = None
182
+ if vision_x is not None:
183
+ vision_tokens = self.vision_tokenizer(self._encode_vision_x(vision_x=vision_x))
184
+ else:
185
+ vision_tokens = None
186
+
187
+ # fuse the vision and language tokens
188
+ new_inputs = self._prepare_inputs_for_forward(
189
+ vision_tokens=vision_tokens,
190
+ lang_x=lang_x,
191
+ attention_mask=attention_mask,
192
+ vision_attention_mask=vision_attention_mask,
193
+ past_key_values=past_key_values,
194
+ past_media_locations=past_media_locations,
195
+ past_vision_tokens=past_vision_tokens,
196
+ padding_side="left",
197
+ num_beams=num_beams,
198
+ )
199
+
200
+ # customize handling of position_ids since attention mask is already formulated as 4D
201
+ if len(new_inputs["attention_mask"].shape) == 4:
202
+ position_ids = new_inputs.get("position_ids", None)
203
+ if position_ids is None:
204
+ seq_length = new_inputs["inputs_embeds"].shape[1]
205
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=new_inputs["inputs_embeds"].device)
206
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
207
+ new_inputs["position_ids"] = position_ids
208
+
209
+ if past_key_values is not None:
210
+ output = self.lang_model.generate(
211
+ **new_inputs,
212
+ past_key_values=past_key_values,
213
+ num_beams=num_beams,
214
+ use_cache=True,
215
+ **kwargs,
216
+ )
217
+ else:
218
+ output = self.lang_model.generate(
219
+ **new_inputs,
220
+ num_beams=num_beams,
221
+ use_cache=True,
222
+ **kwargs,
223
+ )
224
+ self._post_forward_hook()
225
+ return output
226
+
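Note that the constructor above resolves the ids of `<image>` and `<|endofchunk|>` with `tokenizer.convert_tokens_to_ids`, so those tokens must already exist in the tokenizer, and `initial_tokenizer_len` should be the vocabulary size before they were added (it sizes the `DecoupledEmbedding`). A minimal sketch of that wiring, with placeholder paths and a `decoder_layers_attr_name` that assumes a Llama/Phi-style decoder:

```python
# Sketch of constructing AKI from scratch; paths and the layers attribute are assumptions.
from transformers import AutoTokenizer
from src.aki import AKI

tokenizer = AutoTokenizer.from_pretrained("path/to/lang_model")  # placeholder
initial_len = len(tokenizer)  # vocab size *before* the special tokens below are added

# Register the media and end-of-chunk markers expected by AKI / VLMWithLanguageStream.
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>", "<|endofchunk|>"]})

model = AKI(
    vision_encoder_path="path/to/vision_encoder",  # placeholder (e.g. a SigLIP/CLIP repo)
    lang_model_path="path/to/lang_model",          # placeholder
    pad_token_id=tokenizer.pad_token_id,
    initial_tokenizer_len=initial_len,
    tokenizer=tokenizer,                           # triggers set_special_token_ids()
    decoder_layers_attr_name="model.layers",       # assumption: Llama/Phi-style decoder
)
model.set_trainable()  # freezes the vision encoder, unfreezes everything else
```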
src/aki_generation.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+
3
+ from transformers.utils import ModelOutput
4
+ from typing import Any, Dict
5
+
6
+
7
+ def update_causal_attention_mask(attention_mask, cache=False):
8
+ """
9
+ Updates a causal attention mask by expanding it to (n+1, n+1) during generation.
10
+
11
+ Parameters:
12
+ attention_mask (torch.Tensor): Current causal attention mask of shape (1, 1, n, n).
13
+
14
+ Returns:
15
+ torch.Tensor: Updated causal attention mask of shape (1, 1, n+1, n+1).
16
+ """
17
+ # Get the current size `n`
18
+ _, _, n, _ = attention_mask.shape
19
+
20
+ # Create a new row and column with -inf values
21
+ new_row = torch.full((1, 1, 1, n), 1, device=attention_mask.device)
22
+ new_col = torch.full((1, 1, n+1, 1), 0, device=attention_mask.device)
23
+
24
+ new_col[0, 0, -1, -1] = 1
25
+
26
+ # Concatenate the new row and column to the existing mask
27
+ attention_mask = torch.cat([attention_mask, new_row], dim=2) # Add the new row
28
+ attention_mask = torch.cat([attention_mask, new_col], dim=3) # Add the new column
29
+
30
+ if cache:
31
+ return attention_mask[:, :, -1:, :]
32
+ else:
33
+ return attention_mask
34
+
35
+
36
+ def _aki_update_model_kwargs_for_generation(
37
+ self,
38
+ outputs: ModelOutput,
39
+ model_kwargs: Dict[str, Any],
40
+ is_encoder_decoder: bool = False,
41
+ standardize_cache_format: bool = False,
42
+ num_new_tokens: int = 1,
43
+ ) -> Dict[str, Any]:
44
+ # update past_key_values
45
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
46
+ outputs, standardize_cache_format=standardize_cache_format
47
+ )
48
+ if getattr(outputs, "state", None) is not None:
49
+ model_kwargs["state"] = outputs.state
50
+
51
+ # update token_type_ids with last value
52
+ if "token_type_ids" in model_kwargs:
53
+ token_type_ids = model_kwargs["token_type_ids"]
54
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
55
+
56
+ if not is_encoder_decoder:
57
+ # update attention mask
58
+ if "attention_mask" in model_kwargs:
59
+ # modify the update mechanism to incorporate 4D attention mask
60
+ attention_mask = model_kwargs["attention_mask"]
61
+ # after the first computation, roll back to the original 2D attention design to fit Hugging Face's generation logic
62
+ model_kwargs["attention_mask"] = torch.full((1, attention_mask.shape[-1]+1), 1, device=attention_mask.device)
63
+ else:
64
+ # update decoder attention mask
65
+ if "decoder_attention_mask" in model_kwargs:
66
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
67
+ model_kwargs["decoder_attention_mask"] = torch.cat(
68
+ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
69
+ dim=-1,
70
+ )
71
+
72
+ if (
73
+ model_kwargs.get("use_cache", True)
74
+ and "cache_position" in model_kwargs
75
+ and model_kwargs["cache_position"] is not None
76
+ ):
77
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
78
+
79
+ # update position_ids and keep only the last one
80
+ position_ids = torch.arange(model_kwargs["past_key_values"][0][0].shape[2]+1, device=model_kwargs["attention_mask"].device).unsqueeze(0) # +1 for the new token
81
+ if model_kwargs.get("past_key_values", None) is not None:
82
+ position_ids = position_ids[:, -1:]
83
+
84
+ model_kwargs["position_ids"] = position_ids
85
+
86
+ return model_kwargs
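As a quick sanity check of `update_causal_attention_mask` (not part of the repo), the snippet below grows a 3-token causal mask by one step and shows the cached-decoding variant:

```python
import torch
from src.aki_generation import update_causal_attention_mask

# 3-token lower-triangular causal mask, shaped (1, 1, n, n) as the function expects.
mask = torch.tril(torch.ones(3, 3, dtype=torch.long)).view(1, 1, 3, 3)

grown = update_causal_attention_mask(mask)
print(grown.shape)  # torch.Size([1, 1, 4, 4])
# The appended row is all ones (the new token attends to every position), while the
# appended column is zero except on the diagonal (old tokens cannot attend to the new token).

step = update_causal_attention_mask(mask, cache=True)
print(step.shape)   # torch.Size([1, 1, 1, 4]) -> only the new query row, for KV-cache decoding
```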
src/helpers.py ADDED
@@ -0,0 +1,613 @@
1
+ """
2
+ Based on: https://github.com/lucidrains/flamingo-pytorch
3
+ """
4
+
5
+ import re
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from einops_exts import rearrange_many
10
+ from torch import einsum, nn
11
+ from transformers.modeling_outputs import CausalLMOutputWithPast
12
+ from typing import Optional
13
+ from dataclasses import dataclass
14
+
15
+
16
+ @dataclass
17
+ class VLMOutputWithPast(CausalLMOutputWithPast):
18
+ """
19
+ VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
20
+ past_media_locations: Optional[torch.Tensor] = None,
21
+ past_vision_tokens: Optional[torch.Tensor] = None,
22
+ """
23
+
24
+ past_media_locations: Optional[torch.Tensor] = None
25
+ past_vision_tokens: Optional[torch.Tensor] = None
26
+
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+
32
+ def FeedForward(dim, mult=4):
33
+ inner_dim = int(dim * mult)
34
+ return nn.Sequential(
35
+ nn.LayerNorm(dim),
36
+ nn.Linear(dim, inner_dim, bias=False),
37
+ nn.GELU(),
38
+ nn.Linear(inner_dim, dim, bias=False),
39
+ )
40
+
41
+
42
+ class VisionTokenizer(nn.Module):
43
+ def __init__(self, dim_media, num_tokens_per_media):
44
+ super().__init__()
45
+ self.dim_media = dim_media
46
+ self.num_tokens_per_media = num_tokens_per_media
47
+
48
+
49
+ # MLP (not used in the current implementation)
50
+ class MLPVisionProjector(VisionTokenizer):
51
+ def __init__(self, *, dim, dim_inner, num_latents):
52
+ super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
53
+ self.projector = nn.Sequential(
54
+ nn.Linear(dim, dim_inner),
55
+ nn.GELU(),
56
+ nn.Linear(dim_inner, dim_inner),
57
+ )
58
+
59
+ def forward(self, x):
60
+ return self.projector(x)
61
+
62
+ class PerceiverAttention(nn.Module):
63
+ def __init__(self, *, dim, dim_head=64, heads=8):
64
+ super().__init__()
65
+ self.scale = dim_head**-0.5
66
+ self.heads = heads
67
+ inner_dim = dim_head * heads
68
+
69
+ self.norm_media = nn.LayerNorm(dim)
70
+ self.norm_latents = nn.LayerNorm(dim)
71
+
72
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
73
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
74
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
75
+
76
+ def forward(self, x, latents):
77
+ """
78
+ Args:
79
+ x (torch.Tensor): image features
80
+ shape (b, T, n1, D)
81
+ latents (torch.Tensor): latent features
82
+ shape (b, T, n2, D)
83
+ """
84
+ x = self.norm_media(x)
85
+ latents = self.norm_latents(latents)
86
+
87
+ h = self.heads
88
+
89
+ q = self.to_q(latents)
90
+ kv_input = torch.cat((x, latents), dim=-2)
91
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
92
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
93
+ q = q * self.scale
94
+
95
+ # attention
96
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
97
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
98
+ attn = sim.softmax(dim=-1)
99
+
100
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
101
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
102
+ return self.to_out(out)
103
+
104
+
105
+ class PerceiverResampler(VisionTokenizer):
106
+ def __init__(
107
+ self,
108
+ *,
109
+ dim,
110
+ dim_inner=None,
111
+ depth=6,
112
+ dim_head=64,
113
+ heads=8,
114
+ num_latents=64,
115
+ max_num_media=None,
116
+ max_num_frames=None,
117
+ ff_mult=4,
118
+ ):
119
+ """
120
+ Perceiver module which takes in image features and outputs image tokens.
121
+ Args:
122
+ dim (int): dimension of the incoming image features
123
+ dim_inner (int, optional): final dimension to project the incoming image features to;
124
+ also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
125
+ depth (int, optional): number of layers. Defaults to 6.
126
+ dim_head (int, optional): dimension of each head. Defaults to 64.
127
+ heads (int, optional): number of heads. Defaults to 8.
128
+ num_latents (int, optional): number of latent tokens to use in the Perceiver;
129
+ also corresponds to number of tokens per sequence to output. Defaults to 64.
130
+ max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
131
+ and keep positional embeddings for. If None, no positional embeddings are used.
132
+ max_num_frames (int, optional): maximum number of frames to input into the Perceiver
133
+ and keep positional embeddings for. If None, no positional embeddings are used.
134
+ ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
135
+ """
136
+ if dim_inner is not None:
137
+ projection = nn.Linear(dim, dim_inner)
138
+ else:
139
+ projection = None
140
+ dim_inner = dim
141
+ super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
142
+ self.projection = projection
143
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
144
+
145
+ # positional embeddings
146
+ self.frame_embs = (
147
+ nn.Parameter(torch.randn(max_num_frames, dim))
148
+ if exists(max_num_frames)
149
+ else None
150
+ )
151
+ self.media_time_embs = (
152
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
153
+ if exists(max_num_media)
154
+ else None
155
+ )
156
+
157
+ self.layers = nn.ModuleList([])
158
+ for _ in range(depth):
159
+ self.layers.append(
160
+ nn.ModuleList(
161
+ [
162
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
163
+ FeedForward(dim=dim, mult=ff_mult),
164
+ ]
165
+ )
166
+ )
167
+
168
+ self.norm = nn.LayerNorm(dim)
169
+
170
+ def forward(self, x):
171
+ """
172
+ Args:
173
+ x (torch.Tensor): image features
174
+ shape (b, T, F, v, D)
175
+ Returns:
176
+ shape (b, T, n, D) where n is self.num_latents
177
+ """
178
+ b, T, F, v = x.shape[:4]
179
+
180
+ # frame and media time embeddings
181
+ if exists(self.frame_embs):
182
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
183
+ x = x + frame_embs
184
+ x = rearrange(
185
+ x, "b T F v d -> b T (F v) d"
186
+ ) # flatten the frame and spatial dimensions
187
+ if exists(self.media_time_embs):
188
+ x = x + self.media_time_embs[:T]
189
+
190
+ # blocks
191
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
192
+ for attn, ff in self.layers:
193
+ latents = attn(x, latents) + latents
194
+ latents = ff(latents) + latents
195
+
196
+ if exists(self.projection):
197
+ return self.projection(self.norm(latents))
198
+ else:
199
+ return self.norm(latents)
200
+
201
+
202
+ # gated cross attention
203
+ class MaskedCrossAttention(nn.Module):
204
+ def __init__(
205
+ self,
206
+ *,
207
+ dim,
208
+ dim_visual,
209
+ dim_head=64,
210
+ heads=8,
211
+ only_attend_immediate_media=True,
212
+ ):
213
+ super().__init__()
214
+ self.scale = dim_head**-0.5
215
+ self.heads = heads
216
+ inner_dim = dim_head * heads
217
+
218
+ self.norm = nn.LayerNorm(dim)
219
+
220
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
221
+ self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
222
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
223
+
224
+ # whether for text to only attend to immediate preceding image, or all previous images
225
+ self.only_attend_immediate_media = only_attend_immediate_media
226
+
227
+ def forward(self, x, media, media_locations=None, use_cached_media=False):
228
+ """
229
+ Args:
230
+ x (torch.Tensor): text features
231
+ shape (B, T_txt, D_txt)
232
+ media (torch.Tensor): image features
233
+ shape (B, T_img, n, D_img) where n is the dim of the latents
234
+ media_locations: boolean mask identifying the media tokens in x
235
+ shape (B, T_txt)
236
+ use_cached_media: bool
237
+ If true, treat all of x as if they occur after the last media
238
+ registered in media_locations. T_txt does not need to exactly
239
+ equal media_locations.shape[1] in this case
240
+ """
241
+
242
+ if not use_cached_media:
243
+ assert (
244
+ media_locations.shape[1] == x.shape[1]
245
+ ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"
246
+
247
+ T_txt = x.shape[1]
248
+ _, T_img, n = media.shape[:3]
249
+ h = self.heads
250
+
251
+ x = self.norm(x)
252
+
253
+ q = self.to_q(x)
254
+ media = rearrange(media, "b t n d -> b (t n) d")
255
+
256
+ k, v = self.to_kv(media).chunk(2, dim=-1)
257
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
258
+
259
+ q = q * self.scale
260
+
261
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
262
+
263
+ if exists(media_locations):
264
+ media_time = torch.arange(T_img, device=x.device) + 1
265
+
266
+ if use_cached_media:
267
+ # text time is set to the last cached media location
268
+ text_time = repeat(
269
+ torch.count_nonzero(media_locations, dim=1),
270
+ "b -> b i",
271
+ i=T_txt,
272
+ )
273
+ else:
274
+ # at each boolean of True, increment the time counter (relative to media time)
275
+ text_time = media_locations.cumsum(dim=-1)
276
+
277
+ # text time must equal media time if only attending to most immediate image
278
+ # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
279
+ mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
280
+
281
+ text_to_media_mask = mask_op(
282
+ rearrange(text_time, "b i -> b 1 i 1"),
283
+ repeat(media_time, "j -> 1 1 1 (j n)", n=n),
284
+ )
285
+ sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
286
+
287
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
288
+ attn = sim.softmax(dim=-1)
289
+
290
+ if exists(media_locations) and self.only_attend_immediate_media:
291
+ # any text without a preceding media needs to have attention zeroed out
292
+ text_without_media_mask = text_time == 0
293
+ text_without_media_mask = rearrange(
294
+ text_without_media_mask, "b i -> b 1 i 1"
295
+ )
296
+ attn = attn.masked_fill(text_without_media_mask, 0.0)
297
+
298
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
299
+ out = rearrange(out, "b h n d -> b n (h d)")
300
+ return self.to_out(out)
301
+
302
+
303
+ class GatedCrossAttentionBlock(nn.Module):
304
+ def __init__(
305
+ self,
306
+ *,
307
+ dim,
308
+ dim_visual,
309
+ dim_head=64,
310
+ heads=8,
311
+ ff_mult=4,
312
+ only_attend_immediate_media=True,
313
+ ):
314
+ super().__init__()
315
+ self.attn = MaskedCrossAttention(
316
+ dim=dim,
317
+ dim_visual=dim_visual,
318
+ dim_head=dim_head,
319
+ heads=heads,
320
+ only_attend_immediate_media=only_attend_immediate_media,
321
+ )
322
+ self.attn_gate = nn.Parameter(torch.tensor([0.0]))
323
+
324
+ self.ff = FeedForward(dim, mult=ff_mult)
325
+ self.ff_gate = nn.Parameter(torch.tensor([0.0]))
326
+
327
+ def forward(
328
+ self,
329
+ x,
330
+ media,
331
+ media_locations=None,
332
+ use_cached_media=False,
333
+ ):
334
+ x = (
335
+ self.attn(
336
+ x,
337
+ media,
338
+ media_locations=media_locations,
339
+ use_cached_media=use_cached_media,
340
+ )
341
+ * self.attn_gate.tanh()
342
+ + x
343
+ )
344
+ x = self.ff(x) * self.ff_gate.tanh() + x
345
+
346
+ return x
347
+
348
+
349
+ # Both DecoupledEmbedding and DecoupledLinear are taken from https://github.com/huggingface/transformers/blob/v4.32.1/src/transformers/models/idefics/modeling_idefics.py and renamed for clarity
350
+ class DecoupledEmbedding(nn.Embedding):
351
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
352
+ """
353
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
354
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
355
+ then it will create `num_additional_embeddings` additional parameters that are always trained. If
356
+ `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
357
+ """
358
+
359
+ def __init__(
360
+ self,
361
+ max_original_id: int,
362
+ num_additional_embeddings: int = 0,
363
+ _weight: torch.Tensor = None,
364
+ num_original_embeddings: int = None,
365
+ embedding_dim: int = None,
366
+ partially_freeze=True,
367
+ device=None,
368
+ dtype=None,
369
+ pad_token_id=None,
370
+ ) -> None:
371
+ """
372
+ Args:
373
+ max_original_id (`int`):
374
+ The largest token id that should be embedded using the regular embedding (regular `weight`).
375
+ This is usually len(tokenizer) - 1 before additional tokens are added.
376
+ Note that this may not equal self.weight.shape[0]
377
+ num_additional_embeddings (`int`):
378
+ Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
379
+ _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
380
+ If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
381
+ num_original_embeddings (`int`):
382
+ self.weight.shape[0]
383
+ embedding_dim (`int`):
384
+ The size of each embedding vector
385
+ partially_freeze: (`bool`, *optional*, defaults to `True`):
386
+ If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
387
+ padding_idx (`int`, *optional*):
388
+ The padding index (needs to be less than num_embeddings)
389
+
390
+ Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
391
+ `max_norm` or `norm_type`. We are not supporting these.
392
+ """
393
+ # validate args
394
+ if pad_token_id is not None and pad_token_id > max_original_id:
395
+ raise ValueError(
396
+ f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
397
+ + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
398
+ )
399
+ if _weight is not None:
400
+ assert (num_original_embeddings is None) or (
401
+ _weight.shape[0] == num_original_embeddings
402
+ ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
403
+ assert (embedding_dim is None) or (
404
+ _weight.shape[1] == embedding_dim
405
+ ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
406
+ num_original_embeddings = _weight.shape[0]
407
+ embedding_dim = _weight.shape[1]
408
+ else:
409
+ assert (
410
+ num_original_embeddings is not None
411
+ ), "num_original_embeddings must be provided if _weight is not provided"
412
+ assert (
413
+ embedding_dim is not None
414
+ ), "embedding_dim must be provided if _weight is not provided"
415
+
416
+ super().__init__(
417
+ num_embeddings=num_original_embeddings,
418
+ embedding_dim=embedding_dim,
419
+ device=device,
420
+ dtype=dtype,
421
+ padding_idx=pad_token_id,
422
+ _weight=_weight,
423
+ )
424
+ self.max_original_id = max_original_id
425
+ self.padding_idx = pad_token_id
426
+ self.num_additional_embeddings = num_additional_embeddings
427
+ if self.num_additional_embeddings > 0:
428
+ self.additional_embedding = nn.Embedding(
429
+ num_embeddings=self.num_additional_embeddings,
430
+ embedding_dim=embedding_dim,
431
+ device=device,
432
+ dtype=dtype,
433
+ )
434
+ self.set_requires_grad(
435
+ require_regular_grad=not partially_freeze, require_additional_grad=True
436
+ )
437
+
438
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
439
+ """
440
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
441
+ """
442
+ self.weight.requires_grad_(require_regular_grad)
443
+ self.additional_embedding.requires_grad_(require_additional_grad)
444
+
445
+ def forward(self, input_ids):
446
+ """
447
+ we have 2 embeddings, with different indices - one pretrained self.weight and another
448
+ self.additional_embedding.weight that is being trained.
449
+
450
+ in order to make a lookup of the input ids, we:
451
+ 1. find out the indices of the entries belonging to the 2nd embedding
452
+ 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
453
+ embedding starts from 0 and not num_embeddings
454
+ 3. perform the 2nd embedding lookup
455
+ 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
456
+ 5. perform the 1st embedding lookup
457
+ 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
458
+
459
+ note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
460
+ then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
461
+ i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
462
+ usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
463
+ measure.
464
+
465
+ """
466
+ if self.num_additional_embeddings == 0:
467
+ return F.embedding(input_ids, self.weight)
468
+
469
+ # Clone so that we don't modify the original input_ids later on
470
+ input_ids = input_ids.clone()
471
+ additional_vocab_indices = torch.where(input_ids > self.max_original_id)
472
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
473
+ additional_embeddings = self.additional_embedding(
474
+ input_ids_additional_vocab - self.max_original_id - 1
475
+ )
476
+
477
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
478
+ input_ids[additional_vocab_indices] = 0
479
+ full_vector = F.embedding(input_ids, self.weight)
480
+
481
+ # overwrite the records with high indices
482
+ full_vector[additional_vocab_indices] = additional_embeddings
483
+
484
+ return full_vector
485
+
486
+ def extra_repr(self) -> str:
487
+ return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
488
+ self.max_original_id + 1,
489
+ self.num_additional_embeddings,
490
+ self.embedding_dim,
491
+ (not self.weight.requires_grad),
492
+ )
493
+
494
+
495
+ class DecoupledLinear(nn.Linear):
496
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
497
+ """
498
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
499
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
500
+ then it will create `additional_out_features * in_features` additional parameters that are always trained. If
501
+ `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
502
+ """
503
+
504
+ def __init__(
505
+ self,
506
+ max_original_id: int,
507
+ additional_out_features: int = 0,
508
+ _weight: torch.Tensor = None,
509
+ _bias: torch.Tensor = None,
510
+ in_features: int = None,
511
+ original_out_features: int = None,
512
+ bias: bool = True,
513
+ partially_freeze: bool = True,
514
+ device=None,
515
+ dtype=None,
516
+ ) -> None:
517
+ """
518
+ Args:
519
+ max_original_id (`int`): The largest token id that should be extracted from the regular weight.
520
+ This is usually len(tokenizer) - 1 before additional tokens are added.
521
+ Note that this may not equal original_out_features - 1
522
+ _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
523
+ If provided, this sets the `in_features` and `original_out_features` parameters.
524
+ _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
525
+ in_features: int. Input hidden size.
526
+ original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
527
+ additional_out_features: int. Number of additional trainable dimensions.
528
+ bias: bool. Whether to include a bias term.
529
+ partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
530
+ """
531
+ # argument validation
532
+ if _weight is not None:
533
+ assert (_weight.shape[0] == original_out_features) or (
534
+ original_out_features is None
535
+ ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
536
+ assert (_weight.shape[1] == in_features) or (
537
+ in_features is None
538
+ ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
539
+ in_features = _weight.shape[1]
540
+ original_out_features = _weight.shape[0]
541
+ else:
542
+ assert (
543
+ in_features is not None
544
+ ), "in_features must be provided if _weight is not provided"
545
+ assert (
546
+ original_out_features is not None
547
+ ), "original_out_features must be provided if _weight is not provided"
548
+
549
+ if _bias is not None:
550
+ assert bias is True, "bias must be True if _bias is provided"
551
+
552
+ # initialize original linear
553
+ super().__init__(
554
+ in_features,
555
+ original_out_features,
556
+ bias,
557
+ device,
558
+ dtype)
559
+
560
+ # set weight and bias manually
561
+ if _weight is not None:
562
+ self.weight = nn.Parameter(_weight)
563
+ if _bias is not None:
564
+ self.bias = nn.Parameter(_bias)
565
+
566
+ self.in_features = in_features
567
+ self.original_out_features = original_out_features
568
+ self.max_original_id = max_original_id
569
+
570
+ # initialize additional linear
571
+ self.additional_out_features = additional_out_features
572
+ self.has_bias = bias
573
+ if additional_out_features > 0:
574
+ self.additional_fc = nn.Linear(
575
+ in_features=in_features,
576
+ out_features=additional_out_features,
577
+ bias=self.has_bias,
578
+ device=device,
579
+ dtype=dtype,
580
+ )
581
+ self.set_requires_grad(
582
+ require_regular_grad=not partially_freeze, require_additional_grad=True
583
+ )
584
+
585
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
586
+ """
587
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
588
+ """
589
+ self.weight.requires_grad_(require_regular_grad)
590
+ if self.has_bias:
591
+ self.bias.requires_grad_(require_regular_grad)
592
+ self.additional_fc.requires_grad_(require_additional_grad)
593
+
594
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
595
+ output = F.linear(input, self.weight, self.bias)
596
+ output = output[..., : self.max_original_id + 1]
597
+
598
+ if self.additional_out_features > 0:
599
+ additional_features = F.linear(
600
+ input, self.additional_fc.weight, self.additional_fc.bias
601
+ )
602
+ output = torch.cat((output, additional_features), -1)
603
+ return output
604
+
605
+ def extra_repr(self) -> str:
606
+ """Overwriting `nn.Linear.extra_repr` to include new parameters."""
607
+ return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
608
+ self.in_features,
609
+ self.max_original_id + 1,
610
+ self.additional_out_features,
611
+ self.bias is not None,
612
+ (not self.weight.requires_grad or not self.bias.requires_grad),
613
+ )
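For reference, a small sketch of the `PerceiverResampler` shape contract used by `AKI` (the dimensions below are toy values, not the model's real sizes):

```python
import torch
from src.helpers import PerceiverResampler

# dim = incoming vision feature width, dim_inner = language embedding width,
# num_latents = number of visual tokens emitted per image.
resampler = PerceiverResampler(dim=32, dim_inner=64, num_latents=8)

# Image features shaped (b, T_img, F, v, dim); F = 1 (single frame), v = patches per image.
feats = torch.randn(2, 1, 1, 49, 32)
tokens = resampler(feats)
print(tokens.shape)  # torch.Size([2, 1, 8, 64]) -> (b, T_img, num_latents, dim_inner)
```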
src/utils.py ADDED
@@ -0,0 +1,108 @@
1
+ import torch
2
+
3
+
4
+ def extend_instance(obj, mixin):
5
+ """Apply mixins to a class instance after creation"""
6
+ base_cls = obj.__class__
7
+ base_cls_name = obj.__class__.__name__
8
+ obj.__class__ = type(
9
+ base_cls_name, (mixin, base_cls), {}
10
+ ) # mixin needs to go first for our forward() logic to work
11
+
12
+
13
+ def getattr_recursive(obj, att):
14
+ """
15
+ Return nested attribute of obj
16
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
17
+ """
18
+ if att == "":
19
+ return obj
20
+ i = att.find(".")
21
+ if i < 0:
22
+ return getattr(obj, att)
23
+ else:
24
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
25
+
26
+
27
+ def setattr_recursive(obj, att, val):
28
+ """
29
+ Set nested attribute of obj
30
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
31
+ """
32
+ if "." in att:
33
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
34
+ setattr(obj, att.split(".")[-1], val)
35
+
36
+
37
+ def apply_with_stopping_condition(
38
+ module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
39
+ ):
40
+ if stopping_condition(module):
41
+ return
42
+ if apply_condition(module):
43
+ apply_fn(module, **other_args)
44
+ for child in module.children():
45
+ apply_with_stopping_condition(
46
+ child,
47
+ apply_fn,
48
+ apply_condition=apply_condition,
49
+ stopping_condition=stopping_condition,
50
+ **other_args
51
+ )
52
+
53
+
54
+ def num_params(module, filter_to_trainable=False):
55
+ """Returns the number of parameters in the module, or optionally only the trainable parameters"""
56
+ if filter_to_trainable:
57
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
58
+ else:
59
+ return sum(p.numel() for p in module.parameters())
60
+
61
+
62
+ def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
63
+ """
64
+ Stack a list of tensors with padding on one side
65
+ Args:
66
+ list_of_tensors (list[torch.Tensor]): List of tensors to stack
67
+ padding_value (int, optional): Value to pad with. Defaults to 0.
68
+ padding_side (str, optional): Side to pad on. Defaults to "right".
69
+ Returns:
70
+ torch.Tensor: Stacked tensors
71
+ """
72
+ max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
73
+ padded_tensors = []
74
+ for tensor in list_of_tensors:
75
+ num_tokens = tensor.size(0)
76
+ if len(tensor.size()) == 1:
77
+ padding = torch.full(
78
+ (max_tokens - num_tokens,),
79
+ padding_value,
80
+ dtype=tensor.dtype,
81
+ device=tensor.device,
82
+ )
83
+ else:
84
+ padding = torch.full(
85
+ (max_tokens - num_tokens, tensor.size(1)),
86
+ padding_value,
87
+ dtype=tensor.dtype,
88
+ device=tensor.device,
89
+ )
90
+ padded_tensor = (
91
+ torch.cat((tensor, padding), dim=0)
92
+ if padding_side == "right"
93
+ else torch.cat((padding, tensor), dim=0)
94
+ )
95
+ padded_tensors.append(padded_tensor)
96
+ return torch.stack(padded_tensors)
97
+
98
+
99
+ def stack_with_padding_2D_attention(list_of_tensors):
100
+ max_size = max(tensor.size(1) for tensor in list_of_tensors)
101
+ # Initialize a padded tensor of zeros with the target shape
102
+ padded_tensors = []
103
+ for tensor in list_of_tensors:
104
+ a = tensor.shape[-1]
105
+ padding = (0, max_size - a, 0, max_size - a) # (left, right, top, bottom)
106
+ padded_tensor = torch.nn.functional.pad(tensor, padding)
107
+ padded_tensors.append(padded_tensor)
108
+ return torch.stack(padded_tensors)
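A quick illustration of `stack_with_padding`, which `src/vlm.py` imports to batch variable-length sequences (toy tensors only):

```python
import torch
from src.utils import stack_with_padding

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

right = stack_with_padding(seqs, padding_value=0, padding_side="right")
left = stack_with_padding(seqs, padding_value=0, padding_side="left")
print(right)  # tensor([[1, 2, 3], [4, 5, 0]])
print(left)   # tensor([[1, 2, 3], [0, 4, 5]])
```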
src/vlm.py ADDED
@@ -0,0 +1,777 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import nn
4
+ from typing import List, Optional, Tuple, Union
5
+ from .utils import extend_instance, stack_with_padding, stack_with_padding_2D_attention, num_params, getattr_recursive
6
+ from .helpers import DecoupledEmbedding, DecoupledLinear, VLMOutputWithPast
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+ from transformers import CLIPVisionModel
9
+ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
10
+
11
+
12
+ class VLM(nn.Module):
13
+ """
14
+ Generic vision-language model (VLM) class.
15
+ A VLM consists of four components:
16
+ 1. A vision encoder that extracts features from pixels, e.g. CLIP
17
+ input: (B, T_img, F, C, H, W)
18
+ output: (B, T_img, F, v, d)
19
+ 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
20
+ input: (B, T_img, F, v, d)
21
+ output: (B, T_img, n, d)
22
+ 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
23
+ 4. A language model
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ vision_encoder: nn.Module,
29
+ vision_tokenizer: nn.Module,
30
+ lang_model: nn.Module,
31
+ initial_tokenizer_len: int,
32
+ pad_token_id: int,
33
+ gradient_checkpointing: bool = False,
34
+ base_img_size: Optional[int] = None,
35
+ ):
36
+ """
37
+ Args:
38
+ vision_encoder (nn.Module): e.g. CLIP
39
+ vision_tokenizer (nn.Module): e.g. PerceiverResampler
40
+ lang_model (nn.Module): e.g. MPT
41
+ initial_tokenizer_len (int): size of the original tokenizer vocab
42
+ pad_token_id (int): id of the pad token
43
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
44
+ """
45
+ super().__init__()
46
+
47
+ # save dimension information
48
+ self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
49
+ if hasattr(lang_model.config, "d_model"):
50
+ self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
51
+ else:
52
+ self.lang_hidden_dim = lang_model.config.hidden_size
53
+ self.vis_embedding_dim = vision_tokenizer.dim_media
54
+ self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
55
+
56
+ # core components
57
+ self.vision_encoder = vision_encoder
58
+ self.vision_tokenizer = vision_tokenizer
59
+ self.lang_model = lang_model
60
+
61
+ if base_img_size is None:
62
+ if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(self.vision_encoder, SiglipVisionTransformer):
63
+ base_img_size = self.vision_encoder.config.image_size
64
+ else:
65
+ base_img_size = self.vision_encoder.image_size[0]
66
+ self.base_img_size = base_img_size
67
+
68
+ # lm embeddings
69
+ self.pad_token_id = pad_token_id
70
+ self.initial_tokenizer_len = initial_tokenizer_len
71
+ input_embeds = DecoupledEmbedding(
72
+ max_original_id=initial_tokenizer_len - 1,
73
+ num_additional_embeddings=len(self.special_tokens),
74
+ _weight=self.lang_model.get_input_embeddings().weight,
75
+ pad_token_id=self.pad_token_id,
76
+ )
77
+ if hasattr(input_embeds, "additional_embedding"):
78
+ input_embeds.additional_embedding.weight.data.normal_(
79
+ mean=0.0,
80
+ std=self.lang_model.config.initializer_range
81
+ if hasattr(self.lang_model.config, "initializer_range")
82
+ else 0.02,
83
+ )
84
+ self.lang_model.set_input_embeddings(input_embeds)
85
+
86
+ out_embeds = DecoupledLinear(
87
+ max_original_id=initial_tokenizer_len - 1,
88
+ additional_out_features=len(self.special_tokens),
89
+ _weight=self.lang_model.get_output_embeddings().weight,
90
+ _bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
91
+ )
92
+ if hasattr(out_embeds, "additional_fc"):
93
+ out_embeds.additional_fc.weight.data.normal_(
94
+ mean=0.0,
95
+ std=self.lang_model.config.initializer_range
96
+ if hasattr(self.lang_model.config, "initializer_range")
97
+ else 0.02,
98
+ )
99
+ self.lang_model.set_output_embeddings(out_embeds)
100
+
101
+ # gradient checkpointing
102
+ self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
103
+
104
+ def forward(
105
+ self,
106
+ vision_x: Optional[torch.Tensor],
107
+ lang_x: torch.Tensor,
108
+ attention_mask: Optional[torch.Tensor] = None,
109
+ labels: Optional[torch.Tensor] = None,
110
+ past_key_values: Optional[
111
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
112
+ ] = None,
113
+ past_media_locations: Optional[torch.Tensor] = None,
114
+ past_vision_tokens: Optional[torch.Tensor] = None,
115
+ use_cache: Optional[bool] = False,
116
+ **kwargs,
117
+ ):
118
+ """
119
+ Args:
120
+ vision_x: Vision input
121
+ shape (B, T_img, F, C, H, W) with F=1
122
+ only F = 1 is supported (single-frame videos)
123
+ if T_img > the number of media tokens in the corresponding input_ids (lang_x),
124
+ only the first number of media tokens in lang_x are used
125
+ lang_x: Language input ids, with media tokens denoting where
126
+ visual media should be inserted.
127
+ shape (B, T_txt)
128
+ attention_mask: Attention mask. Defaults to None.
129
+ labels: Labels. Defaults to None.
130
+ shape (B, T_txt)
131
+ past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
132
+ list of length = number of decoder layers in the LM
133
+ exact implementation depends on LM, see Hugging Face docs
134
+ past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
135
+ shape (B, T_txt)
136
+ past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
137
+ use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
138
+ If True, includes key_values, media_locations, and vision_tokens in the output.
139
+ """
140
+ assert not (past_vision_tokens is None) ^ (
141
+ past_media_locations is None
142
+ ), "past_vision_tokens and past_media_locations must both be None or both be not None"
143
+
144
+ # convert pixels to vision tokens
145
+ if vision_x is not None:
146
+ vision_features = self._encode_vision_x(vision_x=vision_x)
147
+ vision_tokens = self.vision_tokenizer(vision_features)
148
+ else:
149
+ vision_tokens = None
150
+
151
+ # fuse the vision and language tokens
152
+ new_inputs = self._prepare_inputs_for_forward(
153
+ vision_tokens=vision_tokens,
154
+ lang_x=lang_x,
155
+ attention_mask=attention_mask,
156
+ labels=labels,
157
+ past_key_values=past_key_values,
158
+ past_media_locations=past_media_locations,
159
+ padding_side="right",
160
+ past_vision_tokens=past_vision_tokens,
161
+ )
162
+ output = self.lang_model(
163
+ **new_inputs,
164
+ use_cache=use_cache,
165
+ past_key_values=past_key_values,
166
+ **kwargs,
167
+ )
168
+
169
+ # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
170
+ # or to add the past_vision_tokens and past_media_locations to the output
171
+ output = self._postprocess_outputs_from_forward(
172
+ output=output,
173
+ lang_x=lang_x,
174
+ vision_tokens=vision_tokens,
175
+ use_cache=use_cache,
176
+ past_vision_tokens=past_vision_tokens,
177
+ past_media_locations=past_media_locations,
178
+ )
179
+
180
+ # postforward hooks
181
+ self._post_forward_hook()
182
+ return output
183
+
184
+ def _encode_vision_x(self, vision_x: torch.Tensor):
185
+ """
186
+ Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
187
+ Args:
188
+ vision_x: Vision input
189
+ shape (B, T_img, F, C, H, W)
190
+ Images in the same chunk are collated along T_img, and frames are collated along F
191
+ Currently only F=1 is supported (single-frame videos)
192
+
193
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
194
+ """
195
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
196
+ b, T, F = vision_x.shape[:3]
197
+
198
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
199
+ with torch.no_grad():
200
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
201
+ vision_x = self.vision_encoder.trunk.forward_features(vision_x)
202
+ elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
203
+ vision_x = self.vision_encoder(vision_x).last_hidden_state
204
+ else:
205
+ vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
206
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
207
+ return vision_x
208
+
209
+ def _concat_vision_cache(
210
+ self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
211
+ ):
212
+ """
213
+ Helper function to include the past vision tokens and past media locations in the output.
214
+ """
215
+ if use_cache:
216
+ if past_media_locations is not None and past_vision_tokens is not None:
217
+ if vision_tokens is not None:
218
+ updated_vision_tokens = torch.cat(
219
+ [
220
+ past_vision_tokens,
221
+ vision_tokens,
222
+ ],
223
+ dim=1,
224
+ )
225
+ else:
226
+ updated_vision_tokens = past_vision_tokens
227
+ updated_media_locations = torch.cat(
228
+ [
229
+ past_media_locations,
230
+ lang_x == self.media_token_id,
231
+ ],
232
+ dim=1,
233
+ )
234
+ else:
235
+ updated_vision_tokens = vision_tokens
236
+ updated_media_locations = lang_x == self.media_token_id
237
+
238
+ else:
239
+ updated_vision_tokens = None
240
+ updated_media_locations = None
241
+
242
+ return updated_vision_tokens, updated_media_locations
243
+
244
+ def generate(
245
+ self,
246
+ vision_x: torch.Tensor,
247
+ lang_x: torch.Tensor,
248
+ attention_mask: torch.Tensor = None,
249
+ past_key_values: Optional[
250
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
251
+ ] = None,
252
+ past_media_locations: Optional[torch.Tensor] = None,
253
+ past_vision_tokens: Optional[torch.Tensor] = None,
254
+ **kwargs,
255
+ ):
256
+ """
257
+ Generate text conditioned on vision and language inputs.
258
+ Args:
259
+ vision_x (torch.Tensor): Vision input
260
+ shape (B, T_img, F, C, H, W)
261
+ see documentation for forward
262
+ lang_x (torch.Tensor): Language input
263
+ shape (B, T_txt)
264
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
265
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
266
+ Returns:
267
+ torch.Tensor: lang_x with generated tokens appended to it
268
+ """
269
+ num_beams = kwargs.pop("num_beams", 1)
270
+
271
+ # convert pixels to vision tokens
272
+ if vision_x is not None:
273
+ vision_features = self._encode_vision_x(vision_x=vision_x)
274
+ vision_tokens = self.vision_tokenizer(vision_features)
275
+ else:
276
+ vision_tokens = None
277
+
278
+ # fuse the vision and language tokens
279
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
280
+ # the total batch size is B * num_beams
281
+ new_inputs = self._prepare_inputs_for_forward(
282
+ vision_tokens=vision_tokens,
283
+ lang_x=lang_x,
284
+ attention_mask=attention_mask,
285
+ past_key_values=past_key_values,
286
+ past_media_locations=past_media_locations,
287
+ past_vision_tokens=past_vision_tokens,
288
+ padding_side="left",
289
+ num_beams=num_beams,
290
+ )
291
+ output = self.lang_model.generate(
292
+ **new_inputs,
293
+ past_key_values=past_key_values,
294
+ num_beams=num_beams,
295
+ use_cache=True,
296
+ **kwargs,
297
+ )
298
+ self._post_forward_hook()
299
+ return output
300
+
301
+ @property
302
+ def num_trainable_params(self):
303
+ """Print the number of trainable parameters"""
304
+ return num_params(self, filter_to_trainable=True)
305
+
306
+ def set_trainable(self):
307
+ """
308
+ Freeze appropriate parameters in the model.
309
+ """
310
+ raise NotImplementedError
311
+
312
+ def group_params_by_weight_decay(self):
313
+ """
314
+ Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
315
+ """
316
+ params_with_wd, params_without_wd = [], []
317
+ for n, p in self.named_parameters():
318
+ if p.requires_grad:
319
+ if self._should_apply_weight_decay(n):
320
+ params_with_wd.append(p)
321
+ else:
322
+ params_without_wd.append(p)
323
+ return params_with_wd, params_without_wd
324
+
325
+ def _should_apply_weight_decay(self, parameter_name):
326
+ """
327
+ Return whether weight decay should be applied to a parameter.
328
+ """
329
+ raise NotImplementedError
330
+
331
+ @property
332
+ def special_tokens(self):
333
+ """
334
+ Returns a dict mapping from the attribute name of a special token to its string format,
335
+ e.g. "media_token": "<image>"
336
+ """
337
+ assert (
338
+ "media_token" in self._special_tokens
339
+ ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
340
+ return self._special_tokens
341
+
342
+ @property
343
+ def special_token_ids(self):
344
+ """
345
+ Returns a list of the special token ids
346
+ """
347
+ return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
348
+
349
+ def set_special_token_ids(self, string_to_ids):
350
+ """
351
+ Args:
352
+ string_to_ids (dict): mapping from token string to id
353
+ """
354
+ assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
355
+ for att_name, token_str in self.special_tokens.items():
356
+ token_id = string_to_ids[token_str]
357
+ setattr(self, f"{att_name}_id", token_id)
358
+ setattr(self.lang_model, f"{att_name}_id", token_id)
359
+
360
+ def init_gradient_checkpointing(self):
361
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
362
+ checkpoint_wrapper,
363
+ CheckpointWrapper,
364
+ CheckpointImpl,
365
+ apply_activation_checkpointing,
366
+ )
367
+ from functools import partial
368
+
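+ # Wrap every submodule flagged with `_use_gradient_checkpointing` (and not already
+ # wrapped) in a non-reentrant activation-checkpointing wrapper.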
369
+ non_reentrant_wrapper = partial(
370
+ checkpoint_wrapper,
371
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
372
+ )
373
+ apply_activation_checkpointing(
374
+ self,
375
+ checkpoint_wrapper_fn=non_reentrant_wrapper,
376
+ check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
377
+ and not isinstance(m, CheckpointWrapper),
378
+ )
379
+
380
+
381
+ class VLMWithLanguageStream(VLM):
382
+ """
383
+ VLM that fuses modalities by inserting vision tokens directly into the language stream.
384
+ """
385
+
386
+ def __init__(
387
+ self,
388
+ vision_encoder: nn.Module,
389
+ vision_tokenizer: nn.Module,
390
+ lang_model: nn.Module,
391
+ initial_tokenizer_len: int,
392
+ pad_token_id: int,
393
+ decoder_layers_attr_name: str = None,
394
+ gradient_checkpointing: bool = False,
395
+ base_img_size: Optional[int] = None,
396
+ ):
397
+ super().__init__(
398
+ vision_encoder=vision_encoder,
399
+ vision_tokenizer=vision_tokenizer,
400
+ lang_model=lang_model,
401
+ initial_tokenizer_len=initial_tokenizer_len,
402
+ pad_token_id=pad_token_id,
403
+ base_img_size=base_img_size,
404
+ gradient_checkpointing=gradient_checkpointing,
405
+ )
406
+ self.decoder_layers_attr_name = decoder_layers_attr_name
407
+ for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
408
+ block._use_gradient_checkpointing = gradient_checkpointing
409
+
410
+ @staticmethod
411
+ def _make_modality_mutual_mask(
412
+ attention_mask_2d: torch.Tensor,
413
+ image_start_idx: int,
414
+ text_start_idx: int,
415
+ text_end_idx: int, # the end of the question in the SFT stage
416
+ input_ids_shape: torch.Size,
417
+ dtype: torch.dtype,
418
+ device: torch.device,
419
+ ):
420
+ """
421
+ Make non-causal mask between modalities.
422
+ """
423
+ tgt_len = input_ids_shape[0]
424
+ mask = torch.full((tgt_len, tgt_len), 0, device=device)
425
+ mask_cond = torch.arange(mask.size(-1), device=device)
426
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 1)
427
+
428
+ # enable vision tokens to attend to text tokens
429
+ mask[image_start_idx:text_start_idx, text_start_idx:text_end_idx] = 1
430
+
431
+ mask = mask.to(dtype)
432
+ mask = mask[None, :, :].expand(1, tgt_len, tgt_len)
433
+
434
+ expanded_mask = attention_mask_2d[None, None, :].expand(1, tgt_len, tgt_len).to(torch.float32)
435
+ inverted_mask = 1.0 - expanded_mask
436
+ expanded_attn_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(torch.float32).min)
437
+
438
+ expanded_attn_mask = mask.masked_fill(expanded_attn_mask.bool(), 0)
439
+
440
+ # the causal structure is already baked into the mask above; summing expanded_attn_mask with a separate causal 4-D mask here could overflow
441
+ expanded_4d_mask = expanded_attn_mask
442
+
443
+ return expanded_4d_mask
444
+
445
+ def _prepare_inputs_for_forward(
446
+ self,
447
+ vision_tokens: torch.Tensor,
448
+ lang_x: torch.Tensor,
449
+ attention_mask: torch.Tensor,
450
+ labels: torch.Tensor = None,
451
+ past_key_values=None,
452
+ vision_attention_mask: Optional[torch.Tensor] = None,
453
+ past_media_locations: torch.Tensor = None,
454
+ past_vision_tokens: torch.Tensor = None,
455
+ padding_side: str = "left",
456
+ num_beams: int = 1,
457
+ ):
458
+ """
459
+ Insert the vision tokens directly into the language stream.
460
+ This requires us to modify the input_ids, attention_mask, and labels.
461
+ [NOTE]: This function can be changed to fit the ablation setting of putting text before images.
462
+ """
463
+ if past_key_values is not None:
464
+ past_len = past_key_values[0][0].shape[2]
465
+ assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
466
+ "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
467
+ + "Check that you've expanded the attention mask to account for past image tokens."
468
+ )
469
+
470
+ if vision_tokens is None:
471
+ return {
472
+ "input_ids": lang_x,
473
+ "attention_mask": attention_mask,
474
+ "labels": labels,
475
+ }
476
+
477
+ # get the language embeddings
478
+ lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
479
+
480
+ # build up the multimodal embeddings
481
+ B = lang_x.shape[0]
482
+ has_labels = labels is not None
483
+ multimodal_embeds = []
484
+ multimodal_attention_mask = []
485
+ multimodal_labels = [] if has_labels else None
486
+ for i in range(B):
487
+ # get index of <image> tokens in lang_x[i]
488
+ image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
489
+
490
+ # get the <|assistant|> token index (id hard-coded for now; it could instead be looked up from the tokenizer's special tokens)
491
+ # assume only one <|assistant|> token, i.e., single-turn
492
+ question_token_idx = torch.where(lang_x[i] == 32001)[0]
493
+ if len(question_token_idx) != 0:
494
+ question_token_idx = question_token_idx[0]
495
+ else:
496
+ question_token_idx = 0
497
+
498
+ if len(image_token_idxs) == 0:
499
+ multimodal_embeds.append(lang_embeds[i].clone())
500
+ new_attention_mask = self._make_modality_mutual_mask(
501
+ attention_mask_2d=attention_mask[i].clone(),
502
+ image_start_idx=0,
503
+ text_start_idx=0,
504
+ text_end_idx=question_token_idx,
505
+ input_ids_shape=attention_mask[i].shape,
506
+ dtype=attention_mask[i].dtype,
507
+ device=attention_mask[i].device,
508
+ )
509
+ multimodal_attention_mask.append(new_attention_mask)
510
+ if has_labels:
511
+ multimodal_labels.append(labels[i].clone())
512
+ continue
513
+
514
+ # since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
515
+ # loop through the image_token_idxs and insert the vision tokens
516
+ new_embed = lang_embeds[i].clone()
517
+ new_attention_mask = (
518
+ attention_mask[i].clone() if attention_mask is not None else None
519
+ )
520
+ if has_labels:
521
+ new_label = labels[i].clone()
522
+
523
+ for img_num in range(len(image_token_idxs)):
524
+ img_idx = image_token_idxs[img_num]
525
+ assert (
526
+ vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
527
+ ), f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) \
528
+ vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
529
+ # By default, vision tokens are not padded.
530
+ num_vis_tokens = self.num_tokens_per_vis
531
+ vis_attention_mask = torch.ones(
532
+ num_vis_tokens, dtype=torch.long
533
+ ).to(attention_mask.device)
534
+
535
+ # Offset the rest of image tokens with current num_vis_tokens
536
+ for j in range(img_num+1, len(image_token_idxs)):
537
+ image_token_idxs[j] += num_vis_tokens
538
+
539
+ new_embed = torch.cat(
540
+ (
541
+ new_embed[:img_idx],
542
+ vision_tokens[i][img_num],
543
+ new_embed[img_idx + 1 :],
544
+ ),
545
+ dim=0,
546
+ )
547
+ new_attention_mask = torch.cat(
548
+ (
549
+ new_attention_mask[:img_idx],
550
+ vis_attention_mask,
551
+ new_attention_mask[img_idx + 1 :],
552
+ ),
553
+ dim=0,
554
+ )
555
+
556
+ new_attention_mask = self._make_modality_mutual_mask(
557
+ attention_mask_2d=new_attention_mask,
558
+ image_start_idx=img_idx,
559
+ text_start_idx=img_idx+len(vis_attention_mask), # image index + num vision tokens -> start position of text
560
+ text_end_idx=question_token_idx+len(vis_attention_mask),
561
+ input_ids_shape=new_attention_mask.shape, # (seq_len,) after inserting vision tokens
562
+ dtype=new_attention_mask.dtype,
563
+ device=new_attention_mask.device,
564
+ )
565
+
566
+ if has_labels:
567
+ new_label = torch.cat(
568
+ (
569
+ new_label[:img_idx],
570
+ torch.ones(num_vis_tokens, dtype=torch.long).to(
571
+ labels.device
572
+ )
573
+ * -100,
574
+ new_label[img_idx + 1 :],
575
+ ),
576
+ dim=0,
577
+ )
578
+ multimodal_embeds.append(new_embed)
579
+ multimodal_attention_mask.append(new_attention_mask)
580
+ if has_labels:
581
+ multimodal_labels.append(new_label)
582
+
583
+ # stack
584
+ multimodal_embeds = stack_with_padding(
585
+ multimodal_embeds,
586
+ padding_value=self.pad_token_id,
587
+ padding_side=padding_side,
588
+ )
589
+ multimodal_attention_mask = stack_with_padding_2D_attention(
590
+ multimodal_attention_mask,
591
+ )
592
+ if has_labels:
593
+ multimodal_labels = stack_with_padding(
594
+ multimodal_labels,
595
+ padding_value=-100,
596
+ padding_side=padding_side,
597
+ )
598
+
599
+ return {
600
+ "inputs_embeds": multimodal_embeds,
601
+ "attention_mask": multimodal_attention_mask,
602
+ "labels": multimodal_labels,
603
+ }
604
+
605
+ def _postprocess_outputs_from_forward(
606
+ self,
607
+ output: CausalLMOutputWithPast,
608
+ lang_x: torch.Tensor,
609
+ vision_tokens: torch.Tensor,
610
+ past_vision_tokens: torch.Tensor,
611
+ past_media_locations: torch.Tensor,
612
+ use_cache: bool = False,
613
+ ):
614
+ # Include the past vision tokens and past media locations in the output
615
+ updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
616
+ lang_x=lang_x,
617
+ vision_tokens=vision_tokens,
618
+ past_vision_tokens=past_vision_tokens,
619
+ past_media_locations=past_media_locations,
620
+ use_cache=use_cache,
621
+ )
622
+
623
+ # return logits that are the same shape as the original input_ids
624
+ logits = output.logits
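+ # output.logits is longer than lang_x because vision tokens were spliced in.
+ # Walk both sequences together: keep one logit per original text token, and for
+ # each <image> placeholder keep the logit at its first vision position before
+ # skipping the remaining vision tokens.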
625
+ batch_logits = []
626
+ B, T_txt = lang_x.shape
627
+ for i in range(B):
628
+ sequence_logits = []
629
+ logits_j = 0
630
+ img_id = 0
631
+ for j in range(T_txt):
632
+ if lang_x[i, j] != self.media_token_id:
633
+ sequence_logits.append(logits[i, logits_j])
634
+ logits_j += 1
635
+ else:
636
+ # append the logit for the first image token, then skip over the rest
637
+ # note: the model actually learns to predict <im_patch>, not <image>
638
+ sequence_logits.append(logits[i, logits_j])
639
+ # logits_j += self.num_tokens_per_vis
640
+ # Offset to account for the dynamic number of vision tokens.
641
+ logits_j += vision_tokens[i][img_id].shape[0]
642
+ img_id += 1
643
+ sequence_logits = torch.stack(sequence_logits, dim=0) # (T_txt, vocab_size)
644
+ batch_logits.append(sequence_logits)
645
+
646
+ batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
647
+ # The final logits shape should be the same as the original input_ids shape
648
+ assert batch_logits.shape[:2] == (B, T_txt)
649
+
650
+ # assemble the output
651
+ output = VLMOutputWithPast(
652
+ loss=output.loss,
653
+ logits=batch_logits,
654
+ past_key_values=output.past_key_values,
655
+ hidden_states=output.hidden_states,
656
+ attentions=output.attentions,
657
+ past_media_locations=updated_media_locations,
658
+ past_vision_tokens=updated_vision_tokens,
659
+ )
660
+
661
+ return output
662
+
663
+ def _post_forward_hook(self):
664
+ pass
665
+
666
+ def get_fsdp_lambda_fn(self):
667
+ """
668
+ Returns the lambda function used to decide how to perform FSDP wrapping.
669
+ """
670
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
671
+ CheckpointWrapper,
672
+ )
673
+
674
+ decoder_block_class = getattr_recursive(
675
+ self.lang_model, self.decoder_layers_attr_name
676
+ )[0].__class__
677
+
678
+ def lambda_fn(module: nn.Module):
679
+ if getattr(module, "_use_gradient_checkpointing", False) and not isinstance(
680
+ module, CheckpointWrapper
681
+ ):
682
+ return False
683
+ if module is self.vision_tokenizer:
684
+ return True
685
+ if isinstance(module, decoder_block_class):
686
+ return True
687
+
688
+ return lambda_fn
689
+
690
+ def get_fsdp_wrapping_policy(self):
691
+ """
692
+ Returns the policy used to decide how to perform FSDP wrapping.
693
+ """
694
+ from torch.distributed.fsdp.wrap import _or_policy, _module_wrap_policy, transformer_auto_wrap_policy
695
+ from open_clip.transformer import VisionTransformer, ResidualAttentionBlock
696
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
697
+ from transformers.models.phi.modeling_phi import PhiDecoderLayer
698
+ # for Phi-3 hot fix
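+ # Phi-3's modeling code is loaded dynamically (trust_remote_code), so Phi3DecoderLayer
+ # lives in a generated transformers_modules package whose path includes the model's
+ # commit hash; recover that hash from the class name of the loaded LM.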
699
+ try:
700
+ import importlib
701
+ commit_hash = str(type(self.lang_model)).split('instruct.')[1].split('.modeling')[0]
702
+ module_name = f"transformers_modules.microsoft.Phi-3-mini-128k-instruct.{commit_hash}.modeling_phi3"
703
+ module = importlib.import_module(module_name)
704
+ Phi3DecoderLayer = module.Phi3DecoderLayer
705
+ import_phi3 = True
706
+ except IndexError:
707
+ import_phi3 = False
708
+
709
+
710
+ # hard-code the module classes to wrap
711
+ # vision
712
+ from transformers import SiglipVisionModel
713
+ if isinstance(self.vision_encoder, SiglipVisionModel):
714
+ vit_wrap_policy = functools.partial(_module_wrap_policy, module_classes={SiglipVisionModel})
715
+ from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer, SiglipVisionTransformer, SiglipVisionEmbeddings, SiglipMultiheadAttentionPoolingHead
716
+ # import torch.nn.LayerNorm as LayerNorm
717
+ transformer_layer_cls_vit = {SiglipEncoderLayer, SiglipVisionTransformer, SiglipVisionEmbeddings, SiglipMultiheadAttentionPoolingHead}
718
+ vision_transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_layer_cls_vit)
719
+ vision_wrap_policy = functools.partial(_or_policy, policies=[vit_wrap_policy, vision_transformer_block_policy])
720
+
721
+ else:
722
+ vit_wrap_policy = functools.partial(_module_wrap_policy, module_classes={VisionTransformer, TimmModel})
723
+ # vit_wrap_policy = functools.partial(_module_wrap_policy, module_classes={VisionTransformer})
724
+ # transformer_layer_cls_vit = {ResidualAttentionBlock}
725
+ transformer_layer_cls_vit = {ResidualAttentionBlock, Block}
726
+ # transformer_layer_cls_vit = {Block}
727
+ vision_transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_layer_cls_vit)
728
+ vision_wrap_policy = functools.partial(_or_policy, policies=[vit_wrap_policy, vision_transformer_block_policy])
729
+ # llm
730
+ transformer_layer_cls={LlamaDecoderLayer, PhiDecoderLayer}
731
+ if import_phi3:
732
+ transformer_layer_cls.add(Phi3DecoderLayer)
733
+ llm_transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_layer_cls)
734
+ # vision_tokenizer
735
+ vis_tokenizer_policy = functools.partial(_module_wrap_policy, module_classes={LinearPatchProjection, PerceiverResampler})
736
+ return functools.partial(
737
+ _or_policy,
738
+ policies = [
739
+ vision_wrap_policy,
740
+ llm_transformer_block_policy,
741
+ vis_tokenizer_policy
742
+ ])
743
+
744
+ def group_params_by_weight_decay(self):
745
+ """
746
+ Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
747
+ """
748
+ params_with_wd, params_without_wd = [], []
749
+ for n, p in self.named_parameters():
750
+ if p.requires_grad:
751
+ if "lang_model.model.embed_tokens" in n:
752
+ params_without_wd.append(p)
753
+ else:
754
+ params_with_wd.append(p)
755
+ return params_with_wd, params_without_wd
756
+
757
+ @property
758
+ def num_params_per_module(self):
759
+ """Print the number of parameters per module in the model"""
760
+ return "\n".join(
761
+ [
762
+ f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
763
+ f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
764
+ f"Language model: {num_params(self.lang_model):,} parameters",
765
+ ]
766
+ )
767
+
768
+ @property
769
+ def num_trainable_params_per_module(self):
770
+ """Print the number of trainable parameters per module in the model"""
771
+ return "\n".join(
772
+ [
773
+ f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
774
+ f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
775
+ f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
776
+ ]
777
+ )