DeclanBracken committed on
Commit b172869 · verified · 1 Parent(s): e41b165

Uploading helper functions from MiniCPM-V-2.5 revision e978c4c9b177e8d1f36deeec20edb18377dc2ff7


Uploading helper functions from the original repository so the model can be instantiated with local code once a snapshot of this repo has been downloaded locally.
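
For reference, a minimal sketch of how a locally downloaded snapshot could be instantiated with these helper modules; the local path, dtype, and device below are illustrative assumptions, not part of this commit:

```python
# Minimal sketch (assumed usage, not part of this commit): load the local snapshot so
# that configuration_minicpm.py, modeling_minicpmv.py and resampler.py are executed
# via trust_remote_code. The path, dtype and device are illustrative.
import torch
from transformers import AutoModel, AutoTokenizer

local_snapshot = "./MiniCPM-Llama3-V-2_5"  # hypothetical local snapshot directory

tokenizer = AutoTokenizer.from_pretrained(local_snapshot, trust_remote_code=True)
model = AutoModel.from_pretrained(
    local_snapshot,
    trust_remote_code=True,      # executes the uploaded helper modules
    torch_dtype=torch.float16,   # assumed dtype; adjust for your hardware
)
model = model.to("cuda").eval()
```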

Files changed (3)
  1. configuration_minicpm.py +113 -0
  2. modeling_minicpmv.py +702 -0
  3. resampler.py +813 -0
configuration_minicpm.py ADDED
@@ -0,0 +1,113 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ MiniCPM model configuration"""
21
+ import os
22
+ from typing import Union
23
+
24
+ from transformers.utils import logging
25
+ from transformers import LlamaConfig, PretrainedConfig
26
+ from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionConfig
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class MiniCPMVSliceConfig(PretrainedConfig):
32
+ model_type = "minicpmv"
33
+
34
+ def __init__(
35
+ self,
36
+ patch_size=14,
37
+ max_slice_nums=9,
38
+ scale_resolution=448,
39
+ **kwargs,
40
+ ):
41
+ super().__init__(**kwargs)
42
+ self.patch_size = patch_size
43
+ self.max_slice_nums = max_slice_nums
44
+ self.scale_resolution = scale_resolution
45
+
46
+ @classmethod
47
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
48
+ cls._set_token_in_kwargs(kwargs)
49
+
50
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
51
+
52
+ if config_dict.get("model_type") == "minicpmv":
53
+ config_dict = config_dict["slice_config"]
54
+
55
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
56
+ logger.warning(
57
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
58
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
59
+ )
60
+
61
+ return cls.from_dict(config_dict, **kwargs)
62
+
63
+
64
+
65
+ class MiniCPMVConfig(LlamaConfig):
66
+ model_type = "minicpmv"
67
+ keys_to_ignore_at_inference = ["past_key_values"]
68
+
69
+ default_vision_config = {
70
+ "hidden_size": 1152,
71
+ "image_size": 980,
72
+ "intermediate_size": 4304,
73
+ "model_type": "idefics2",
74
+ "num_attention_heads": 16,
75
+ "num_hidden_layers": 27,
76
+ "patch_size": 14,
77
+ }
78
+
79
+ def __init__(
80
+ self,
81
+ use_cache=True,
82
+ query_num=64,
83
+ image_size=448,
84
+ drop_vision_last_layer=True,
85
+ batch_vision_input=True,
86
+ slice_config=None,
87
+ vision_config=None,
88
+ **kwargs,
89
+ ):
90
+ self.use_cache = use_cache
91
+ self.query_num = query_num
92
+ self.image_size = image_size
93
+ self.drop_vision_last_layer = drop_vision_last_layer
94
+ self.batch_vision_input = batch_vision_input
95
+
96
+ if slice_config is None:
97
+ self.slice_config = MiniCPMVSliceConfig(max_slice_nums=1)
98
+ else:
99
+ self.slice_config = MiniCPMVSliceConfig(**slice_config)
100
+ self.slice_mode = True
101
+
102
+ # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
103
+ if vision_config is None:
104
+ self.vision_config = Idefics2VisionConfig(**self.default_vision_config)
105
+ logger.info("vision_config is None, using default vision config")
106
+ elif isinstance(vision_config, dict):
107
+ self.vision_config = Idefics2VisionConfig(**vision_config)
108
+ elif isinstance(vision_config, Idefics2VisionConfig):
109
+ self.vision_config = vision_config
110
+
111
+ self.patch_size = self.vision_config.patch_size
112
+
113
+ super().__init__(**kwargs)
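
As a quick illustration of the configuration above (not part of the uploaded file), a config could be built directly like this; the values shown are arbitrary examples rather than recommended settings:

```python
# Hedged sketch: construct MiniCPMVConfig directly with an explicit slice_config.
# The numeric values below are illustrative only.
from configuration_minicpm import MiniCPMVConfig

config = MiniCPMVConfig(
    query_num=64,
    image_size=448,
    drop_vision_last_layer=True,
    slice_config={"patch_size": 14, "max_slice_nums": 9, "scale_resolution": 448},
    # vision_config=None falls back to default_vision_config (SigLIP-so400m-like)
)
print(config.slice_config.max_slice_nums)  # 9
print(config.vision_config.hidden_size)    # 1152, from default_vision_config
```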
modeling_minicpmv.py ADDED
@@ -0,0 +1,702 @@
1
+ import math
2
+ from typing import List, Optional
3
+ import json
4
+ import torch
5
+ import torchvision
6
+ from threading import Thread
7
+ from copy import deepcopy
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from transformers import LlamaTokenizer, LlamaPreTrainedModel, LlamaForCausalLM, AutoModel, PreTrainedTokenizerFast, TextIteratorStreamer
11
+ from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
12
+
13
+ from .configuration_minicpm import MiniCPMVConfig
14
+ from .resampler import Resampler
15
+
16
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
17
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
18
+
19
+ class MiniCPMVPreTrainedModel(LlamaPreTrainedModel):
20
+ config_class = MiniCPMVConfig
21
+
22
+
23
+ class MiniCPMV(MiniCPMVPreTrainedModel):
24
+ def __init__(self, config):
25
+ super().__init__(config)
26
+
27
+ self.llm = LlamaForCausalLM(config)
28
+ self.vpm = self.init_vision_module()
29
+ self.vision_dim = self.vpm.embed_dim
30
+ self.embed_dim = self.llm.config.hidden_size
31
+ self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
32
+ self.transform = self.init_transform()
33
+
34
+ def init_vision_module(self):
35
+ # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
36
+ model = Idefics2VisionTransformer(self.config.vision_config)
37
+ if self.config.drop_vision_last_layer:
38
+ model.encoder.layers = model.encoder.layers[:-1]
39
+
40
+ setattr(model, 'embed_dim', model.embeddings.embed_dim)
41
+ setattr(model, 'patch_size', model.embeddings.patch_size)
42
+
43
+ return model
44
+
45
+ def init_resampler(self, embed_dim, vision_dim):
46
+ return Resampler(
47
+ num_queries=self.config.query_num,
48
+ embed_dim=embed_dim,
49
+ num_heads=embed_dim // 128,
50
+ kv_dim=vision_dim,
51
+ adaptive=True,
52
+ )
53
+
54
+ def init_transform(self):
55
+ return transforms.Compose(
56
+ [
57
+ transforms.ToTensor(),
58
+ transforms.Normalize(
59
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
60
+ ),
61
+ ]
62
+ )
63
+
64
+ def get_input_embeddings(self):
65
+ return self.llm.get_input_embeddings()
66
+
67
+ def set_input_embeddings(self, value):
68
+ self.llm.embed_tokens = value
69
+
70
+ def get_vllm_embedding(self, data):
71
+ if 'vision_hidden_states' not in data:
72
+ dtype = self.llm.model.embed_tokens.weight.dtype
73
+ device = self.llm.model.embed_tokens.weight.device
74
+ tgt_sizes = data['tgt_sizes']
75
+ pixel_values_list = data['pixel_values']
76
+ vision_hidden_states = []
77
+ all_pixel_values = []
78
+ img_cnt = []
79
+ for pixel_values in pixel_values_list:
80
+ img_cnt.append(len(pixel_values))
81
+ all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
82
+
83
+ # at least one image is present
84
+ if all_pixel_values:
85
+ tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
86
+
87
+ if self.config.batch_vision_input:
88
+ max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
89
+
90
+ all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True,
91
+ padding_value=0.0)
92
+ B, L, _ = all_pixel_values.shape
93
+ all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
94
+
95
+ patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
96
+ for i in range(B):
97
+ patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True  # index the singleton dim so only valid patches are unmasked
98
+
99
+ vision_embedding = self.vpm(all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state
100
+ vision_embedding = self.resampler(vision_embedding, tgt_sizes)
101
+ else:
102
+ # get vision_embedding for each image separately
103
+ vision_embedding = []
104
+ for single_tgt_size, single_pixel_values in zip(tgt_sizes, all_pixel_values):
105
+ single_pixel_values = single_pixel_values.unsqueeze(0)
106
+ B, L, _ = single_pixel_values.shape
107
+ single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
108
+ single_vision_embedding = self.vpm(single_pixel_values.type(dtype)).last_hidden_state
109
+ single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
110
+
111
+ vision_embedding.append(single_vision_embedding)
112
+ vision_embedding = torch.vstack(vision_embedding)
113
+
114
+ start = 0
115
+ for pixel_values in pixel_values_list:
116
+ img_cnt = len(pixel_values)
117
+ if img_cnt > 0:
118
+ vision_hidden_states.append(vision_embedding[start: start + img_cnt])
119
+ start += img_cnt
120
+ else:
121
+ vision_hidden_states.append([])
122
+ else: # no image
123
+ if self.training:
124
+ dummy_image = torch.zeros(
125
+ (1, 3, 224, 224),
126
+ device=device, dtype=dtype
127
+ )
128
+ tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32)
129
+ dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
130
+ else:
131
+ dummy_feature = []
132
+ for _ in range(len(pixel_values_list)):
133
+ vision_hidden_states.append(dummy_feature)
134
+
135
+ else:
136
+ vision_hidden_states = data['vision_hidden_states']
137
+
138
+ if hasattr(self.llm.config, 'scale_emb'):
139
+ vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
140
+ else:
141
+ vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
142
+
143
+ vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
144
+ i, torch.Tensor) else i for i in vision_hidden_states]
145
+
146
+ bs = len(data['input_ids'])
147
+ for i in range(bs):
148
+ cur_vs_hs = vision_hidden_states[i]
149
+ if len(cur_vs_hs) > 0:
150
+ cur_vllm_emb = vllm_embedding[i]
151
+ cur_image_bound = data['image_bound'][i]
152
+ if len(cur_image_bound) > 0:
153
+ image_indices = torch.stack(
154
+ [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
155
+ ).to(vllm_embedding.device)
156
+ cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
157
+ cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
158
+ elif self.training:
159
+ cur_vllm_emb += cur_vs_hs[0].mean() * 0
160
+
161
+ return vllm_embedding, vision_hidden_states
162
+
163
+ def forward(self, data, **kwargs):
164
+ vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
165
+ position_ids = data["position_ids"]
166
+ if position_ids.dtype != torch.int64:
167
+ position_ids = position_ids.long()
168
+
169
+ return self.llm(
170
+ input_ids=None,
171
+ position_ids=position_ids,
172
+ inputs_embeds=vllm_embedding,
173
+ **kwargs
174
+ )
175
+
176
+ def _convert_to_tensors(
177
+ self, tokenizer, input_ids, max_inp_length: Optional[int] = None
178
+ ):
179
+ if max_inp_length is not None:
180
+ input_ids = input_ids[:max_inp_length]
181
+ input_ids = torch.tensor(input_ids, dtype=torch.int32)
182
+
183
+ image_start_tokens = torch.where(input_ids == tokenizer.im_start_id)[0]
184
+ # skip the im_start token
185
+ image_start_tokens += 1
186
+ image_end_tokens = torch.where(input_ids == tokenizer.im_end_id)[0]
187
+ valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
188
+ image_bound = torch.hstack(
189
+ [
190
+ image_start_tokens[:valid_image_nums].unsqueeze(-1),
191
+ image_end_tokens[:valid_image_nums].unsqueeze(-1),
192
+ ]
193
+ )
194
+
195
+ model_input = {}
196
+ model_input["input_ids"] = input_ids.unsqueeze(0).to(self.device)
197
+ model_input["image_bound"] = image_bound
198
+
199
+ return model_input
200
+
201
+ def _process_list(
202
+ self, tokenizer, input_id_list, max_inp_length: Optional[int] = None
203
+ ):
204
+ pad_keys = ["input_ids"]
205
+ input_tensors = []
206
+ for input_ids in input_id_list:
207
+ input_tensors.append(
208
+ self._convert_to_tensors(tokenizer, input_ids, max_inp_length)
209
+ )
210
+ padded = {}
211
+ for key in pad_keys:
212
+ padded[key] = pad(input_tensors, key, padding_side="left").to(self.device)
213
+ padded["image_bound"] = [i["image_bound"] for i in input_tensors]
214
+ return padded
215
+
216
+ def _decode(self, inputs_embeds, tokenizer, **kwargs):
217
+ terminators = [
218
+ tokenizer.eos_token_id,
219
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
220
+ ]
221
+ output = self.llm.generate(
222
+ inputs_embeds=inputs_embeds,
223
+ pad_token_id=0,
224
+ eos_token_id=terminators,
225
+ **kwargs
226
+ )
227
+ return self._decode_text(output, tokenizer)
228
+
229
+ def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
230
+ terminators = [
231
+ tokenizer.eos_token_id,
232
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
233
+ ]
234
+ streamer = TextIteratorStreamer(tokenizer=tokenizer)
235
+ generation_kwargs = {
236
+ 'inputs_embeds': inputs_embeds,
237
+ 'pad_token_id': 0,
238
+ 'eos_token_id': terminators,
239
+ 'streamer': streamer
240
+ }
241
+ generation_kwargs.update(kwargs)
242
+
243
+ thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
244
+ thread.start()
245
+
246
+ return streamer
247
+
248
+ def _decode_text(self, result_ids, tokenizer):
249
+ result_text = []
250
+ for result in result_ids:
251
+ result = result[result != 0]
252
+ if result[0] == tokenizer.bos_id:
253
+ result = result[1:]
254
+ if result[-1] == tokenizer.eos_id or result[-1] == tokenizer.eot_id:
255
+ result = result[:-1]
256
+ result_text.append(tokenizer.decode(result).strip())
257
+ return result_text
258
+
259
+ def slice_image(self, image):
260
+ return slice_image(
261
+ image,
262
+ self.config.slice_config.max_slice_nums,
263
+ self.config.slice_config.scale_resolution,
264
+ self.config.slice_config.patch_size,
265
+ )
266
+
267
+ def get_slice_image_placeholder(self, image, tokenizer):
268
+ image_placeholder = (
269
+ tokenizer.im_start
270
+ + tokenizer.unk_token * self.config.query_num
271
+ + tokenizer.im_end
272
+ )
273
+
274
+ slice_images = []
275
+
276
+ source_image, patches, best_grid = slice_image(
277
+ image,
278
+ self.config.slice_config.max_slice_nums,
279
+ self.config.slice_config.scale_resolution,
280
+ self.config.slice_config.patch_size,
281
+ )
282
+
283
+ slice_images.append(source_image)
284
+ final_placeholder = image_placeholder
285
+
286
+ if len(patches) > 0:
287
+ for i in range(len(patches)):
288
+ for j in range(len(patches[0])):
289
+ slice_images.append(patches[i][j])
290
+
291
+ final_placeholder += get_grid_placeholder(
292
+ tokenizer, best_grid, self.config.query_num
293
+ )
294
+
295
+ return slice_images, final_placeholder
296
+
297
+ def reshape_by_patch(self, image_tensor):
298
+ """
299
+ :param image_tensor: shape [3, H, W]
300
+ :param patch_size:
301
+ :return: [3, patch_size, HW/patch_size]
302
+ """
303
+ patch_size = self.config.patch_size
304
+ patches = torch.nn.functional.unfold(
305
+ image_tensor,
306
+ (patch_size, patch_size),
307
+ stride=(patch_size, patch_size)
308
+ )
309
+
310
+ patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
311
+ patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1)
312
+ return patches
313
+
314
+ def generate(
315
+ self,
316
+ input_id_list=None,
317
+ img_list=None,
318
+ tgt_sizes=None,
319
+ tokenizer=None,
320
+ max_inp_length: Optional[int] = None,
321
+ vision_hidden_states=None,
322
+ return_vision_hidden_states=False,
323
+ stream=False,
324
+ **kwargs
325
+ ):
326
+
327
+ assert input_id_list is not None
328
+ bs = len(input_id_list)
329
+ if img_list is None:
330
+ img_list = [[] for i in range(bs)]
331
+ assert bs == len(img_list)
332
+
333
+ model_inputs = self._process_list(tokenizer, input_id_list, max_inp_length)
334
+
335
+ if vision_hidden_states is None:
336
+ pixel_values = []
337
+ for i in range(bs):
338
+ img_inps = []
339
+ for img in img_list[i]:
340
+ img_inps.append(img.to(self.device))
341
+ if img_inps:
342
+ pixel_values.append(img_inps)
343
+ else:
344
+ pixel_values.append([])
345
+ model_inputs["pixel_values"] = pixel_values
346
+ model_inputs['tgt_sizes'] = tgt_sizes
347
+ else:
348
+ model_inputs["vision_hidden_states"] = vision_hidden_states
349
+
350
+ with torch.inference_mode():
351
+ (
352
+ model_inputs["inputs_embeds"],
353
+ vision_hidden_states,
354
+ ) = self.get_vllm_embedding(model_inputs)
355
+
356
+ if stream:
357
+ result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
358
+ else:
359
+ result = self._decode(model_inputs["inputs_embeds"], tokenizer, **kwargs)
360
+
361
+ if return_vision_hidden_states:
362
+ return result, vision_hidden_states
363
+
364
+ return result
365
+
366
+ def chat(
367
+ self,
368
+ image,
369
+ msgs,
370
+ tokenizer,
371
+ vision_hidden_states=None,
372
+ max_new_tokens=1024,
373
+ sampling=True,
374
+ max_inp_length=2048,
375
+ system_prompt='',
376
+ stream=False,
377
+ **kwargs
378
+ ):
379
+ if isinstance(msgs, str):
380
+ msgs = json.loads(msgs)
381
+
382
+ copy_msgs = deepcopy(msgs)
383
+ assert len(copy_msgs) > 0, 'msgs is empty'
384
+ assert sampling or not stream, 'if use stream mode, make sure sampling=True'
385
+
386
+ if image is not None and isinstance(copy_msgs[0]['content'], str):
387
+ copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
388
+
389
+ images = []
390
+ tgt_sizes = []
391
+ for i, msg in enumerate(copy_msgs):
392
+ role = msg["role"]
393
+ content = msg["content"]
394
+ assert role in ["user", "assistant"]
395
+ if i == 0:
396
+ assert role == "user", "The role of first msg should be user"
397
+ if isinstance(content, str):
398
+ content = [content]
399
+
400
+ cur_msgs = []
401
+ for c in content:
402
+ if isinstance(c, Image.Image):
403
+ image = c
404
+ if self.config.slice_mode:
405
+ slice_images, image_placeholder = self.get_slice_image_placeholder(
406
+ image, tokenizer
407
+ )
408
+ cur_msgs.append(image_placeholder)
409
+ for slice_image in slice_images:
410
+ slice_image = self.transform(slice_image)
411
+ H, W = slice_image.shape[1:]
412
+ images.append(self.reshape_by_patch(slice_image))
413
+ tgt_sizes.append(torch.Tensor([H // self.config.patch_size, W // self.config.patch_size]).type(torch.int32))
414
+ else:
415
+ images.append(self.transform(image))
416
+ cur_msgs.append(
417
+ tokenizer.im_start
418
+ + tokenizer.unk_token * self.config.query_num
419
+ + tokenizer.im_end
420
+ )
421
+ elif isinstance(c, str):
422
+ cur_msgs.append(c)
423
+
424
+
425
+ msg['content'] = '\n'.join(cur_msgs)
426
+ if tgt_sizes:
427
+ tgt_sizes = torch.vstack(tgt_sizes)
428
+
429
+ if system_prompt:
430
+ sys_msg = {'role': 'system', 'content': system_prompt}
431
+ copy_msgs = [sys_msg] + copy_msgs
432
+
433
+ input_ids = tokenizer.apply_chat_template(copy_msgs, tokenize=True, add_generation_prompt=False)
434
+
435
+ if sampling:
436
+ generation_config = {
437
+ "top_p": 0.8,
438
+ "top_k": 100,
439
+ "temperature": 0.7,
440
+ "do_sample": True,
441
+ "repetition_penalty": 1.05
442
+ }
443
+ else:
444
+ generation_config = {
445
+ "num_beams": 3,
446
+ "repetition_penalty": 1.2,
447
+ }
448
+
449
+ generation_config.update(
450
+ (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
451
+ )
452
+
453
+ with torch.inference_mode():
454
+ res, vision_hidden_states = self.generate(
455
+ input_id_list=[input_ids],
456
+ max_inp_length=max_inp_length,
457
+ img_list=[images],
458
+ tgt_sizes=[tgt_sizes],
459
+ tokenizer=tokenizer,
460
+ max_new_tokens=max_new_tokens,
461
+ vision_hidden_states=vision_hidden_states,
462
+ return_vision_hidden_states=True,
463
+ stream=stream,
464
+ **generation_config
465
+ )
466
+
467
+ if stream:
468
+ def stream_gen():
469
+ for text in res:
470
+ text = text.replace(tokenizer.eot_token, '').replace(tokenizer.eos_token, '')
471
+ yield text
472
+ return stream_gen()
473
+
474
+ else:
475
+ answer = res[0]
476
+ return answer
477
+
478
+
479
+ class PreTrainedTokenizerFastWrapper(PreTrainedTokenizerFast):
480
+ def __init__(self, **kwargs):
481
+ super().__init__(**kwargs)
482
+ self.eot_token = "<|eot_id|>"
483
+ self.im_start = "<image>"
484
+ self.im_end = "</image>"
485
+ self.ref_start = "<ref>"
486
+ self.ref_end = "</ref>"
487
+ self.box_start = "<box>"
488
+ self.box_end = "</box>"
489
+ self.quad_start = "<quad>"
490
+ self.quad_end = "</quad>"
491
+ self.slice_start = "<slice>"
492
+ self.slice_end = "</slice>"
493
+
494
+ @property
495
+ def eos_id(self):
496
+ return self.eos_token_id
497
+
498
+ @property
499
+ def bos_id(self):
500
+ return self.bos_token_id
501
+
502
+ @property
503
+ def unk_id(self):
504
+ return self.unk_token_id
505
+
506
+ @property
507
+ def eot_id(self):
508
+ return self.convert_tokens_to_ids(self.eot_token)
509
+
510
+ @property
511
+ def im_start_id(self):
512
+ return self.convert_tokens_to_ids(self.im_start)
513
+
514
+ @property
515
+ def im_end_id(self):
516
+ return self.convert_tokens_to_ids(self.im_end)
517
+
518
+ @staticmethod
519
+ def escape(text: str) -> str:
520
+ return text
521
+
522
+ @staticmethod
523
+ def unescape(text: str) -> str:
524
+ return text
525
+
526
+
527
+ def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
528
+ items = []
529
+ if isinstance(orig_items[0][key], list):
530
+ assert isinstance(orig_items[0][key][0], torch.Tensor)
531
+ for it in orig_items:
532
+ for tr in it[key]:
533
+ items.append({key: tr})
534
+ else:
535
+ assert isinstance(orig_items[0][key], torch.Tensor)
536
+ items = orig_items
537
+
538
+ batch_size = len(items)
539
+ shape = items[0][key].shape
540
+ dim = len(shape)
541
+ assert dim <= 3
542
+ if max_length is None:
543
+ max_length = 0
544
+ max_length = max(max_length, max(item[key].shape[-1] for item in items))
545
+ min_length = min(item[key].shape[-1] for item in items)
546
+ dtype = items[0][key].dtype
547
+
548
+ if dim == 1:
549
+ return torch.cat([item[key] for item in items], dim=0)
550
+ elif dim == 2:
551
+ if max_length == min_length:
552
+ return torch.cat([item[key] for item in items], dim=0)
553
+ tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
554
+ else:
555
+ tensor = (
556
+ torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype)
557
+ + padding_value
558
+ )
559
+
560
+ for i, item in enumerate(items):
561
+ if dim == 2:
562
+ if padding_side == "left":
563
+ tensor[i, -len(item[key][0]) :] = item[key][0].clone()
564
+ else:
565
+ tensor[i, : len(item[key][0])] = item[key][0].clone()
566
+ elif dim == 3:
567
+ if padding_side == "left":
568
+ tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
569
+ else:
570
+ tensor[i, : len(item[key][0]), :] = item[key][0].clone()
571
+
572
+ return tensor
573
+
574
+
575
+ def slice_image(
576
+ image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
577
+ ):
578
+ original_size = image.size
579
+ original_width, original_height = original_size
580
+ log_ratio = math.log(original_width / original_height)
581
+ ratio = original_width * original_height / (scale_resolution * scale_resolution)
582
+ multiple = min(math.ceil(ratio), max_slice_nums)
583
+
584
+ source_image = None
585
+ best_grid = None
586
+ patches = []
587
+
588
+ if multiple <= 1 or never_split:
589
+ # no need to slice; just upsample
590
+ best_size = find_best_resize(
591
+ original_size, scale_resolution, patch_size, allow_upscale=True
592
+ )
593
+ source_image = image.resize(best_size, Image.Resampling.BICUBIC)
594
+ else:
595
+ candidate_split_grids_nums = []
596
+ for i in [multiple - 1, multiple, multiple + 1]:
597
+ if i == 1 or i > max_slice_nums:
598
+ continue
599
+ candidate_split_grids_nums.append(i)
600
+
601
+ # source image, down-sampling and ensure divided by patch_size
602
+ best_resize = find_best_resize(original_size, scale_resolution, patch_size)
603
+ source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
604
+ candidate_grids = []
605
+
606
+ # find best grid
607
+ for split_grids_nums in candidate_split_grids_nums:
608
+ m = 1
609
+ while m <= split_grids_nums:
610
+ if split_grids_nums % m == 0:
611
+ candidate_grids.append([m, split_grids_nums // m])
612
+ m += 1
613
+
614
+ best_grid = [1, 1]
615
+ min_error = float("inf")
616
+ for grid in candidate_grids:
617
+ error = abs(log_ratio - math.log(grid[0] / grid[1]))
618
+ if error < min_error:
619
+ best_grid = grid
620
+ min_error = error
621
+
622
+ refine_size = get_refine_size(
623
+ original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
624
+ )
625
+
626
+ refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
627
+ patches = split_to_patches(refine_image, best_grid)
628
+
629
+ return source_image, patches, best_grid
630
+
631
+
632
+ def ensure_divide(length, patch_size):
633
+ return max(round(length / patch_size) * patch_size, patch_size)
634
+
635
+
636
+ def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):
637
+ width, height = original_size
638
+ if (width * height > scale_resolution * scale_resolution) or allow_upscale:
639
+ r = width / height
640
+ height = int(scale_resolution / math.sqrt(r))
641
+ width = int(height * r)
642
+ best_width = ensure_divide(width, patch_size)
643
+ best_height = ensure_divide(height, patch_size)
644
+ return (best_width, best_height)
645
+
646
+
647
+ def get_refine_size(
648
+ original_size, grid, scale_resolution, patch_size, allow_upscale=False
649
+ ):
650
+ width, height = original_size
651
+ grid_x, grid_y = grid
652
+
653
+ refine_width = ensure_divide(width, grid_x)
654
+ refine_height = ensure_divide(height, grid_y)
655
+
656
+ grid_width = refine_width / grid_x
657
+ grid_height = refine_height / grid_y
658
+
659
+ best_grid_size = find_best_resize(
660
+ (grid_width, grid_height),
661
+ scale_resolution,
662
+ patch_size,
663
+ allow_upscale=allow_upscale,
664
+ )
665
+
666
+ refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
667
+
668
+ return refine_size
669
+
670
+
671
+ def split_to_patches(image, grid):
672
+ patches = []
673
+ width, height = image.size
674
+ grid_x = int(width / grid[0])
675
+ grid_y = int(height / grid[1])
676
+
677
+ for i in range(0, height, grid_y):
678
+ images = []
679
+ for j in range(0, width, grid_x):
680
+ box = (j, i, j + grid_x, i + grid_y)
681
+ patch = image.crop(box)
682
+ images.append(patch)
683
+ patches.append(images)
684
+
685
+ return patches
686
+
687
+
688
+ def get_grid_placeholder(tokenizer, grid, query_num):
689
+ image_placeholder = (
690
+ tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
691
+ )
692
+
693
+ cols = grid[0]
694
+ rows = grid[1]
695
+ slices = []
696
+ for i in range(rows):
697
+ lines = []
698
+ for j in range(cols):
699
+ lines.append(image_placeholder)
700
+ slices.append("".join(lines))
701
+ slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end
702
+ return slice_placeholder
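
To show how the class above is typically driven (an assumed usage sketch, reusing the `model` and `tokenizer` from the loading sketch near the top of this commit; the image path and question are placeholders):

```python
# Assumed usage of MiniCPMV.chat(); `model` and `tokenizer` come from the loading
# sketch above, and "example.jpg" is a placeholder path.
from PIL import Image

image = Image.open("example.jpg").convert("RGB")
msgs = [{"role": "user", "content": "What is in this image?"}]

answer = model.chat(
    image=image,
    msgs=msgs,
    tokenizer=tokenizer,
    sampling=True,        # uses the top_p/top_k/temperature defaults defined above
    max_new_tokens=256,
)
print(answer)
```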
resampler.py ADDED
@@ -0,0 +1,813 @@
1
+ from functools import partial
2
+ import numpy as np
3
+ import warnings
4
+ from typing import Optional, Tuple
5
+ import torch
6
+ from torch import nn
7
+ from torch import Tensor
8
+ import deepspeed
9
+ import torch.nn.functional as F
10
+ from torch.nn.functional import *
11
+ from torch.nn.modules.activation import *
12
+ from torch.nn.init import trunc_normal_
13
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
14
+ from transformers import PreTrainedModel
15
+ from transformers.integrations import is_deepspeed_zero3_enabled
16
+
17
+ def get_2d_sincos_pos_embed(embed_dim, image_size):
18
+ """
19
+ image_size: image_size or (image_height, image_width)
20
+ return:
21
+ pos_embed: [image_height, image_width, embed_dim]
22
+ """
23
+ if isinstance(image_size, int):
24
+ grid_h_size, grid_w_size = image_size, image_size
25
+ else:
26
+ grid_h_size, grid_w_size = image_size[0], image_size[1]
27
+
28
+ grid_h = np.arange(grid_h_size, dtype=np.float32)
29
+ grid_w = np.arange(grid_w_size, dtype=np.float32)
30
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
31
+ grid = np.stack(grid, axis=0)
32
+
33
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
34
+ return pos_embed
35
+
36
+
37
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
38
+ assert embed_dim % 2 == 0
39
+
40
+ # use half of dimensions to encode grid_h
41
+ emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[0]) # (H, W, D/2)
42
+ emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[1]) # (H, W, D/2)
43
+
44
+ emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
45
+ return emb
46
+
47
+
48
+ def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
49
+ """
50
+ embed_dim: output dimension for each position
51
+ pos: a list of positions to be encoded: size (H, W)
52
+ out: (H, W, D)
53
+ """
54
+ assert embed_dim % 2 == 0
55
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
56
+ omega /= embed_dim / 2.
57
+ omega = 1. / 10000 ** omega # (D/2,)
58
+
59
+ out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product
60
+
61
+ emb_sin = np.sin(out) # (H, W, D/2)
62
+ emb_cos = np.cos(out) # (H, W, D/2)
63
+
64
+ emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
65
+ return emb
66
+
67
+
68
+ class Resampler(nn.Module):
69
+ """
70
+ A 2D perceiver-resampler network with one cross-attention layer, driven by
71
+ given learnable queries and 2d sincos pos_emb
72
+ Outputs:
73
+ A tensor with the shape of (batch_size, num_queries, embed_dim)
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ num_queries,
79
+ embed_dim,
80
+ num_heads,
81
+ kv_dim=None,
82
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
83
+ adaptive=False,
84
+ max_size=(70, 70),
85
+ ):
86
+ super().__init__()
87
+ self.num_queries = num_queries
88
+ self.embed_dim = embed_dim
89
+ self.num_heads = num_heads
90
+ self.adaptive = adaptive
91
+ self.max_size = max_size
92
+
93
+ self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
94
+
95
+ if kv_dim is not None and kv_dim != embed_dim:
96
+ self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
97
+ else:
98
+ self.kv_proj = nn.Identity()
99
+
100
+ self.attn = MultiheadAttention(embed_dim, num_heads)
101
+ self.ln_q = norm_layer(embed_dim)
102
+ self.ln_kv = norm_layer(embed_dim)
103
+
104
+ self.ln_post = norm_layer(embed_dim)
105
+ self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
106
+
107
+ self._set_2d_pos_cache(self.max_size)
108
+
109
+ def _set_2d_pos_cache(self, max_size, device='cpu'):
110
+ if is_deepspeed_zero3_enabled():
111
+ device='cuda'
112
+ pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float().to(device)
113
+ self.register_buffer("pos_embed", pos_embed, persistent=False)
114
+
115
+ def _adjust_pos_cache(self, tgt_sizes, device):
116
+ max_h = torch.max(tgt_sizes[:, 0])
117
+ max_w = torch.max(tgt_sizes[:, 1])
118
+ if max_h > self.max_size[0] or max_w > self.max_size[1]:
119
+ self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
120
+ self._set_2d_pos_cache(self.max_size, device)
121
+
122
+ def _init_weights(self, m):
123
+ if isinstance(m, nn.Linear):
124
+ trunc_normal_(m.weight, std=.02)
125
+ if isinstance(m, nn.Linear) and m.bias is not None:
126
+ nn.init.constant_(m.bias, 0)
127
+ elif isinstance(m, nn.LayerNorm):
128
+ nn.init.constant_(m.bias, 0)
129
+ nn.init.constant_(m.weight, 1.0)
130
+
131
+ def forward(self, x, tgt_sizes=None):
132
+ assert x.shape[0] == tgt_sizes.shape[0]
133
+ bs = x.shape[0]
134
+
135
+ device = x.device
136
+ dtype = x.dtype
137
+
138
+ patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
139
+
140
+ self._adjust_pos_cache(tgt_sizes, device=device)
141
+
142
+ max_patch_len = torch.max(patch_len)
143
+ key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device)
144
+
145
+ pos_embed = []
146
+ for i in range(bs):
147
+ tgt_h, tgt_w = tgt_sizes[i]
148
+ pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)) # patches * D
149
+ key_padding_mask[i, patch_len[i]:] = True
150
+
151
+ pos_embed = torch.nn.utils.rnn.pad_sequence(
152
+ pos_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D
153
+
154
+ x = self.kv_proj(x) # B * L * D
155
+ x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
156
+
157
+ q = self.ln_q(self.query) # Q * D
158
+
159
+ out = self.attn(
160
+ self._repeat(q, bs), # Q * B * D
161
+ x + pos_embed, # L * B * D + L * B * D
162
+ x,
163
+ key_padding_mask=key_padding_mask)[0]
164
+ # out: Q * B * D
165
+ x = out.permute(1, 0, 2) # B * Q * D
166
+
167
+ x = self.ln_post(x)
168
+ x = x @ self.proj
169
+ return x
170
+
171
+ def _repeat(self, query, N: int):
172
+ return query.unsqueeze(1).repeat(1, N, 1)
173
+
174
+
175
+ class MultiheadAttention(nn.MultiheadAttention):
176
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
177
+ add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None):
178
+ super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype)
179
+
180
+ # rewrite the out_proj layer with nn.Linear
181
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
182
+
183
+ def forward(
184
+ self,
185
+ query: Tensor,
186
+ key: Tensor,
187
+ value: Tensor,
188
+ key_padding_mask: Optional[Tensor] = None,
189
+ need_weights: bool = True,
190
+ attn_mask: Optional[Tensor] = None,
191
+ average_attn_weights: bool = True,
192
+ is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
193
+ why_not_fast_path = ''
194
+ if ((attn_mask is not None and torch.is_floating_point(attn_mask))
195
+ or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
196
+ why_not_fast_path = "floating-point masks are not supported for fast path."
197
+
198
+ is_batched = query.dim() == 3
199
+
200
+ key_padding_mask = F._canonical_mask(
201
+ mask=key_padding_mask,
202
+ mask_name="key_padding_mask",
203
+ other_type=F._none_or_dtype(attn_mask),
204
+ other_name="attn_mask",
205
+ target_type=query.dtype
206
+ )
207
+
208
+ attn_mask = F._canonical_mask(
209
+ mask=attn_mask,
210
+ mask_name="attn_mask",
211
+ other_type=None,
212
+ other_name="",
213
+ target_type=query.dtype,
214
+ check_other=False,
215
+ )
216
+
217
+
218
+ if not is_batched:
219
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
220
+ elif query is not key or key is not value:
221
+ # When lifting this restriction, don't forget to either
222
+ # enforce that the dtypes all match or test cases where
223
+ # they don't!
224
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
225
+ elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
226
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
227
+ elif self.in_proj_weight is None:
228
+ why_not_fast_path = "in_proj_weight was None"
229
+ elif query.dtype != self.in_proj_weight.dtype:
230
+ # this case will fail anyway, but at least they'll get a useful error message.
231
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
232
+ elif self.training:
233
+ why_not_fast_path = "training is enabled"
234
+ elif (self.num_heads % 2) != 0:
235
+ why_not_fast_path = "self.num_heads is not even"
236
+ elif not self.batch_first:
237
+ why_not_fast_path = "batch_first was not True"
238
+ elif self.bias_k is not None:
239
+ why_not_fast_path = "self.bias_k was not None"
240
+ elif self.bias_v is not None:
241
+ why_not_fast_path = "self.bias_v was not None"
242
+ elif self.add_zero_attn:
243
+ why_not_fast_path = "add_zero_attn was enabled"
244
+ elif not self._qkv_same_embed_dim:
245
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
246
+ elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
247
+ why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
248
+ is not supported with NestedTensor input"
249
+ elif torch.is_autocast_enabled():
250
+ why_not_fast_path = "autocast is enabled"
251
+
252
+ if not why_not_fast_path:
253
+ tensor_args = (
254
+ query,
255
+ key,
256
+ value,
257
+ self.in_proj_weight,
258
+ self.in_proj_bias,
259
+ self.out_proj.weight,
260
+ self.out_proj.bias,
261
+ )
262
+ # We have to use list comprehensions below because TorchScript does not support
263
+ # generator expressions.
264
+ if torch.overrides.has_torch_function(tensor_args):
265
+ why_not_fast_path = "some Tensor argument has_torch_function"
266
+ elif _is_make_fx_tracing():
267
+ why_not_fast_path = "we are running make_fx tracing"
268
+ elif not all(_check_arg_device(x) for x in tensor_args):
269
+ why_not_fast_path = ("some Tensor argument's device is neither one of "
270
+ f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
271
+ elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
272
+ why_not_fast_path = ("grad is enabled and at least one of query or the "
273
+ "input/output projection weights or biases requires_grad")
274
+ if not why_not_fast_path:
275
+ merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
276
+
277
+ if self.in_proj_bias is not None and self.in_proj_weight is not None:
278
+ return torch._native_multi_head_attention(
279
+ query,
280
+ key,
281
+ value,
282
+ self.embed_dim,
283
+ self.num_heads,
284
+ self.in_proj_weight,
285
+ self.in_proj_bias,
286
+ self.out_proj.weight,
287
+ self.out_proj.bias,
288
+ merged_mask,
289
+ need_weights,
290
+ average_attn_weights,
291
+ mask_type)
292
+
293
+ any_nested = query.is_nested or key.is_nested or value.is_nested
294
+ assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
295
+ f"The fast path was not hit because {why_not_fast_path}")
296
+
297
+ if self.batch_first and is_batched:
298
+ # make sure that the transpose op does not affect the "is" property
299
+ if key is value:
300
+ if query is key:
301
+ query = key = value = query.transpose(1, 0)
302
+ else:
303
+ query, key = (x.transpose(1, 0) for x in (query, key))
304
+ value = key
305
+ else:
306
+ query, key, value = (x.transpose(1, 0) for x in (query, key, value))
307
+
308
+ if not self._qkv_same_embed_dim:
309
+ attn_output, attn_output_weights = self.multi_head_attention_forward(
310
+ query, key, value, self.embed_dim, self.num_heads,
311
+ self.in_proj_weight, self.in_proj_bias,
312
+ self.bias_k, self.bias_v, self.add_zero_attn,
313
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
314
+ training=self.training,
315
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
316
+ attn_mask=attn_mask,
317
+ use_separate_proj_weight=True,
318
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
319
+ v_proj_weight=self.v_proj_weight,
320
+ average_attn_weights=average_attn_weights,
321
+ is_causal=is_causal)
322
+ else:
323
+ attn_output, attn_output_weights = self.multi_head_attention_forward(
324
+ query, key, value, self.embed_dim, self.num_heads,
325
+ self.in_proj_weight, self.in_proj_bias,
326
+ self.bias_k, self.bias_v, self.add_zero_attn,
327
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
328
+ training=self.training,
329
+ key_padding_mask=key_padding_mask,
330
+ need_weights=need_weights,
331
+ attn_mask=attn_mask,
332
+ average_attn_weights=average_attn_weights,
333
+ is_causal=is_causal)
334
+ if self.batch_first and is_batched:
335
+ return attn_output.transpose(1, 0), attn_output_weights
336
+ else:
337
+ return attn_output, attn_output_weights
338
+
339
+ def multi_head_attention_forward(
340
+ self,
341
+ query: Tensor,
342
+ key: Tensor,
343
+ value: Tensor,
344
+ embed_dim_to_check: int,
345
+ num_heads: int,
346
+ in_proj_weight: Optional[Tensor],
347
+ in_proj_bias: Optional[Tensor],
348
+ bias_k: Optional[Tensor],
349
+ bias_v: Optional[Tensor],
350
+ add_zero_attn: bool,
351
+ dropout_p: float,
352
+ out_proj_weight: Tensor,
353
+ out_proj_bias: Optional[Tensor],
354
+ training: bool = True,
355
+ key_padding_mask: Optional[Tensor] = None,
356
+ need_weights: bool = True,
357
+ attn_mask: Optional[Tensor] = None,
358
+ use_separate_proj_weight: bool = False,
359
+ q_proj_weight: Optional[Tensor] = None,
360
+ k_proj_weight: Optional[Tensor] = None,
361
+ v_proj_weight: Optional[Tensor] = None,
362
+ static_k: Optional[Tensor] = None,
363
+ static_v: Optional[Tensor] = None,
364
+ average_attn_weights: bool = True,
365
+ is_causal: bool = False,
366
+ ) -> Tuple[Tensor, Optional[Tensor]]:
367
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
368
+ if has_torch_function(tens_ops):
369
+ return handle_torch_function(
370
+ multi_head_attention_forward,
371
+ tens_ops,
372
+ query,
373
+ key,
374
+ value,
375
+ embed_dim_to_check,
376
+ num_heads,
377
+ in_proj_weight,
378
+ in_proj_bias,
379
+ bias_k,
380
+ bias_v,
381
+ add_zero_attn,
382
+ dropout_p,
383
+ out_proj_weight,
384
+ out_proj_bias,
385
+ training=training,
386
+ key_padding_mask=key_padding_mask,
387
+ need_weights=need_weights,
388
+ attn_mask=attn_mask,
389
+ is_causal=is_causal,
390
+ use_separate_proj_weight=use_separate_proj_weight,
391
+ q_proj_weight=q_proj_weight,
392
+ k_proj_weight=k_proj_weight,
393
+ v_proj_weight=v_proj_weight,
394
+ static_k=static_k,
395
+ static_v=static_v,
396
+ average_attn_weights=average_attn_weights,
397
+ )
398
+
399
+ is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
400
+
401
+ # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
402
+ # is batched, run the computation and before returning squeeze the
403
+ # batch dimension so that the output doesn't carry this temporary batch dimension.
404
+ if not is_batched:
405
+ # unsqueeze if the input is unbatched
406
+ query = query.unsqueeze(1)
407
+ key = key.unsqueeze(1)
408
+ value = value.unsqueeze(1)
409
+ if key_padding_mask is not None:
410
+ key_padding_mask = key_padding_mask.unsqueeze(0)
411
+
412
+ # set up shape vars
413
+ tgt_len, bsz, embed_dim = query.shape
414
+ src_len, _, _ = key.shape
415
+
416
+ key_padding_mask = _canonical_mask(
417
+ mask=key_padding_mask,
418
+ mask_name="key_padding_mask",
419
+ other_type=_none_or_dtype(attn_mask),
420
+ other_name="attn_mask",
421
+ target_type=query.dtype
422
+ )
423
+
424
+ if is_causal and attn_mask is None:
425
+ raise RuntimeError(
426
+ "Need attn_mask if specifying the is_causal hint. "
427
+ "You may use the Transformer module method "
428
+ "`generate_square_subsequent_mask` to create this mask."
429
+ )
430
+
431
+ if is_causal and key_padding_mask is None and not need_weights:
432
+ # when we have a kpm or need weights, we need attn_mask
433
+ # Otherwise, we use the is_causal hint go as is_causal
434
+ # indicator to SDPA.
435
+ attn_mask = None
436
+ else:
437
+ attn_mask = _canonical_mask(
438
+ mask=attn_mask,
439
+ mask_name="attn_mask",
440
+ other_type=None,
441
+ other_name="",
442
+ target_type=query.dtype,
443
+ check_other=False,
444
+ )
445
+
446
+ if key_padding_mask is not None:
447
+ # We have the attn_mask, and use that to merge kpm into it.
448
+ # Turn off use of is_causal hint, as the merged mask is no
449
+ # longer causal.
450
+ is_causal = False
451
+
452
+ assert embed_dim == embed_dim_to_check, \
453
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
454
+ if isinstance(embed_dim, torch.Tensor):
455
+ # embed_dim can be a tensor when JIT tracing
456
+ head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
457
+ else:
458
+ head_dim = embed_dim // num_heads
459
+ assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
460
+ if use_separate_proj_weight:
461
+ # allow MHA to have different embedding dimensions when separate projection weights are used
462
+ assert key.shape[:2] == value.shape[:2], \
463
+ f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
464
+ else:
465
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
466
+
467
+ #
468
+ # compute in-projection
469
+ #
470
+ if not use_separate_proj_weight:
471
+ assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
472
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
473
+ else:
474
+ assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
475
+ assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
476
+ assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
477
+ if in_proj_bias is None:
478
+ b_q = b_k = b_v = None
479
+ else:
480
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
481
+ q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
482
+
483
+ # prep attention mask
484
+
485
+ if attn_mask is not None:
486
+ # ensure attn_mask's dim is 3
487
+ if attn_mask.dim() == 2:
488
+ correct_2d_size = (tgt_len, src_len)
489
+ if attn_mask.shape != correct_2d_size:
490
+ raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
491
+ attn_mask = attn_mask.unsqueeze(0)
492
+ elif attn_mask.dim() == 3:
493
+ correct_3d_size = (bsz * num_heads, tgt_len, src_len)
494
+ if attn_mask.shape != correct_3d_size:
495
+ raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
496
+ else:
497
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
498
+
499
+ # add bias along batch dimension (currently second)
500
+ if bias_k is not None and bias_v is not None:
501
+ assert static_k is None, "bias cannot be added to static key."
502
+ assert static_v is None, "bias cannot be added to static value."
503
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
504
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
505
+ if attn_mask is not None:
506
+ attn_mask = pad(attn_mask, (0, 1))
507
+ if key_padding_mask is not None:
508
+ key_padding_mask = pad(key_padding_mask, (0, 1))
509
+ else:
510
+ assert bias_k is None
511
+ assert bias_v is None
512
+
513
+ #
514
+ # reshape q, k, v for multihead attention and make em batch first
515
+ #
516
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
517
+ if static_k is None:
518
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
519
+ else:
520
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
521
+ assert static_k.size(0) == bsz * num_heads, \
522
+ f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
523
+ assert static_k.size(2) == head_dim, \
524
+ f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
525
+ k = static_k
526
+ if static_v is None:
527
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
528
+ else:
529
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
530
+ assert static_v.size(0) == bsz * num_heads, \
531
+ f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
532
+ assert static_v.size(2) == head_dim, \
533
+ f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
534
+ v = static_v
535
+
536
+ # add zero attention along batch dimension (now first)
537
+ if add_zero_attn:
538
+ zero_attn_shape = (bsz * num_heads, 1, head_dim)
539
+ k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
540
+ v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
541
+ if attn_mask is not None:
542
+ attn_mask = pad(attn_mask, (0, 1))
543
+ if key_padding_mask is not None:
544
+ key_padding_mask = pad(key_padding_mask, (0, 1))
545
+
546
+ # update source sequence length after adjustments
547
+ src_len = k.size(1)
548
+
549
+ # merge key padding and attention masks
550
+ if key_padding_mask is not None:
551
+ assert key_padding_mask.shape == (bsz, src_len), \
552
+ f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
553
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
554
+ expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
555
+ if attn_mask is None:
556
+ attn_mask = key_padding_mask
557
+ else:
558
+ attn_mask = attn_mask + key_padding_mask
559
+
560
+ # adjust dropout probability
561
+ if not training:
562
+ dropout_p = 0.0
563
+
564
+ #
565
+ # (deep breath) calculate attention and out projection
566
+ #
567
+
568
+ if need_weights:
569
+ B, Nt, E = q.shape
570
+ q_scaled = q / math.sqrt(E)
571
+
572
+ assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
573
+
574
+ if attn_mask is not None:
575
+ attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
576
+ else:
577
+ attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
578
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
579
+ if dropout_p > 0.0:
580
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p)
581
+
582
+ attn_output = torch.bmm(attn_output_weights, v)
583
+
584
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
585
+ attn_output = self.out_proj(attn_output)
586
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
587
+
588
+ # optionally average attention weights over heads
589
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
590
+ if average_attn_weights:
591
+ attn_output_weights = attn_output_weights.mean(dim=1)
592
+
593
+ if not is_batched:
594
+ # squeeze the output if input was unbatched
595
+ attn_output = attn_output.squeeze(1)
596
+ attn_output_weights = attn_output_weights.squeeze(0)
597
+ return attn_output, attn_output_weights
598
+ else:
599
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
600
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
601
+ # in order to match the input for SDPA of (N, num_heads, L, S)
602
+ if attn_mask is not None:
603
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
604
+ attn_mask = attn_mask.unsqueeze(0)
605
+ else:
606
+ attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
607
+
608
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
609
+ k = k.view(bsz, num_heads, src_len, head_dim)
610
+ v = v.view(bsz, num_heads, src_len, head_dim)
611
+
612
+ attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
613
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
614
+
615
+ attn_output = self.out_proj(attn_output)
616
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
617
+ if not is_batched:
618
+ # squeeze the output if input was unbatched
619
+ attn_output = attn_output.squeeze(1)
620
+ return attn_output, None
621
+
622
+
623
+ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
624
+ key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int):
625
+ # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
626
+ # and returns if the input is batched or not.
627
+ # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
628
+
629
+ # Shape check.
630
+ if query.dim() == 3:
631
+ # Batched Inputs
632
+ is_batched = True
633
+ assert key.dim() == 3 and value.dim() == 3, \
634
+ ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
635
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
636
+ if key_padding_mask is not None:
637
+ assert key_padding_mask.dim() == 2, \
638
+ ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
639
+ f" but found {key_padding_mask.dim()}-D tensor instead")
640
+ if attn_mask is not None:
641
+ assert attn_mask.dim() in (2, 3), \
642
+ ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
643
+ f" but found {attn_mask.dim()}-D tensor instead")
644
+ elif query.dim() == 2:
645
+ # Unbatched Inputs
646
+ is_batched = False
647
+ assert key.dim() == 2 and value.dim() == 2, \
648
+ ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
649
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
650
+
651
+ if key_padding_mask is not None:
652
+ assert key_padding_mask.dim() == 1, \
653
+ ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
654
+ f" but found {key_padding_mask.dim()}-D tensor instead")
655
+
656
+ if attn_mask is not None:
657
+ assert attn_mask.dim() in (2, 3), \
658
+ ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
659
+ f" but found {attn_mask.dim()}-D tensor instead")
660
+ if attn_mask.dim() == 3:
661
+ expected_shape = (num_heads, query.shape[0], key.shape[0])
662
+ assert attn_mask.shape == expected_shape, \
663
+ (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
664
+ else:
665
+ raise AssertionError(
666
+ f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
667
+
668
+ return is_batched
669
+
670
+
671
+ def _canonical_mask(
672
+ mask: Optional[Tensor],
673
+ mask_name: str,
674
+ other_type: Optional[DType],
675
+ other_name: str,
676
+ target_type: DType,
677
+ check_other: bool = True,
678
+ ) -> Optional[Tensor]:
679
+
680
+ if mask is not None:
681
+ _mask_dtype = mask.dtype
682
+ _mask_is_float = torch.is_floating_point(mask)
683
+ if _mask_dtype != torch.bool and not _mask_is_float:
684
+ raise AssertionError(
685
+ f"only bool and floating types of {mask_name} are supported")
686
+ if check_other and other_type is not None:
687
+ if _mask_dtype != other_type:
688
+ warnings.warn(
689
+ f"Support for mismatched {mask_name} and {other_name} "
690
+ "is deprecated. Use same type for both instead."
691
+ )
692
+ if not _mask_is_float:
693
+ mask = (
694
+ torch.zeros_like(mask, dtype=target_type)
695
+ .masked_fill_(mask, float("-inf"))
696
+ )
697
+ return mask
698
+
699
+
700
+ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
701
+ if input is None:
702
+ return None
703
+ elif isinstance(input, torch.Tensor):
704
+ return input.dtype
705
+ raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
706
+
707
+ def _in_projection_packed(
708
+ q: Tensor,
709
+ k: Tensor,
710
+ v: Tensor,
711
+ w: Tensor,
712
+ b: Optional[Tensor] = None,
713
+ ) -> List[Tensor]:
714
+ r"""
715
+ Performs the in-projection step of the attention operation, using packed weights.
716
+ Output is a triple containing projection tensors for query, key and value.
717
+ Args:
718
+ q, k, v: query, key and value tensors to be projected. For self-attention,
719
+ these are typically the same tensor; for encoder-decoder attention,
720
+ k and v are typically the same tensor. (We take advantage of these
721
+ identities for performance if they are present.) Regardless, q, k and v
722
+ must share a common embedding dimension; otherwise their shapes may vary.
723
+ w: projection weights for q, k and v, packed into a single tensor. Weights
724
+ are packed along dimension 0, in q, k, v order.
725
+ b: optional projection biases for q, k and v, packed into a single tensor
726
+ in q, k, v order.
727
+ Shape:
728
+ Inputs:
729
+ - q: :math:`(..., E)` where E is the embedding dimension
730
+ - k: :math:`(..., E)` where E is the embedding dimension
731
+ - v: :math:`(..., E)` where E is the embedding dimension
732
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
733
+ - b: :math:`E * 3` where E is the embedding dimension
734
+ Output:
735
+ - in output list :math:`[q', k', v']`, each output tensor will have the
736
+ same shape as the corresponding input tensor.
737
+ """
738
+ E = q.size(-1)
739
+ if k is v:
740
+ if q is k:
741
+ # self-attention
742
+ proj = linear(q, w, b)
743
+ # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
744
+ proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
745
+ return proj[0], proj[1], proj[2]
746
+ else:
747
+ # encoder-decoder attention
748
+ w_q, w_kv = w.split([E, E * 2])
749
+ if b is None:
750
+ b_q = b_kv = None
751
+ else:
752
+ b_q, b_kv = b.split([E, E * 2])
753
+ q_proj = linear(q, w_q, b_q)
754
+ kv_proj = linear(k, w_kv, b_kv)
755
+ # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
756
+ kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
757
+ return (q_proj, kv_proj[0], kv_proj[1])
758
+ else:
759
+ w_q, w_k, w_v = w.chunk(3)
760
+ if b is None:
761
+ b_q = b_k = b_v = None
762
+ else:
763
+ b_q, b_k, b_v = b.chunk(3)
764
+ return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
765
+
766
+
767
+ def _in_projection(
768
+ q: Tensor,
769
+ k: Tensor,
770
+ v: Tensor,
771
+ w_q: Tensor,
772
+ w_k: Tensor,
773
+ w_v: Tensor,
774
+ b_q: Optional[Tensor] = None,
775
+ b_k: Optional[Tensor] = None,
776
+ b_v: Optional[Tensor] = None,
777
+ ) -> Tuple[Tensor, Tensor, Tensor]:
778
+ r"""
779
+ Performs the in-projection step of the attention operation. This is simply
780
+ a triple of linear projections, with shape constraints on the weights which
781
+ ensure embedding dimension uniformity in the projected outputs.
782
+ Output is a triple containing projection tensors for query, key and value.
783
+ Args:
784
+ q, k, v: query, key and value tensors to be projected.
785
+ w_q, w_k, w_v: weights for q, k and v, respectively.
786
+ b_q, b_k, b_v: optional biases for q, k and v, respectively.
787
+ Shape:
788
+ Inputs:
789
+ - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
790
+ number of leading dimensions.
791
+ - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
792
+ number of leading dimensions.
793
+ - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
794
+ number of leading dimensions.
795
+ - w_q: :math:`(Eq, Eq)`
796
+ - w_k: :math:`(Eq, Ek)`
797
+ - w_v: :math:`(Eq, Ev)`
798
+ - b_q: :math:`(Eq)`
799
+ - b_k: :math:`(Eq)`
800
+ - b_v: :math:`(Eq)`
801
+ Output: in output triple :math:`(q', k', v')`,
802
+ - q': :math:`[Qdims..., Eq]`
803
+ - k': :math:`[Kdims..., Eq]`
804
+ - v': :math:`[Vdims..., Eq]`
805
+ """
806
+ Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
807
+ assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
808
+ assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
809
+ assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
810
+ assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
811
+ assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
812
+ assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
813
+ return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
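
Finally, a standalone sketch of the Resampler defined above, assuming its module-level imports (including deepspeed) are installed; the dimensions mirror the defaults used elsewhere in this repo but are otherwise illustrative:

```python
# Standalone sketch: compress padded vision features (B, L, kv_dim) into a fixed
# number of query tokens. Dimensions are illustrative; deepspeed must be installed
# because resampler.py imports it at module level.
import torch
from resampler import Resampler

resampler = Resampler(
    num_queries=64,          # matches config.query_num in this repo
    embed_dim=4096,          # LLM hidden size
    num_heads=4096 // 128,   # same ratio as init_resampler() in modeling_minicpmv.py
    kv_dim=1152,             # vision tower hidden size
    adaptive=True,
)

feats = torch.randn(2, 32 * 32, 1152)                  # two images, 32x32 patches each
tgt_sizes = torch.tensor([[32, 32], [32, 32]], dtype=torch.int32)

out = resampler(feats, tgt_sizes)
print(out.shape)  # torch.Size([2, 64, 4096])
```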