PyTorch · English · internlm2 · code · custom_code
AkashahS committed (verified)
Commit 6918706 · 1 Parent(s): c4bccce

Upload folder using huggingface_hub
Files changed (3)
  1. config.json +3 -3
  2. geopixel.py +411 -0
  3. pytorch_model.bin.index.json +2 -2
config.json CHANGED
@@ -1,13 +1,13 @@
 {
-  "_name_or_path": "./model/IXC",
+  "_name_or_path": "AkashahS/GeoPixel-7B",
   "architectures": [
-    "GeoPixForCausalLM"
+    "GeoPixelForCausalLM"
   ],
   "attn_implementation": "flash_attention_2",
   "auto_map": {
     "AutoConfig": "configuration_internlm_xcomposer2.InternLMXcomposer2Config",
     "AutoModel": "modeling_internlm_xcomposer2.InternLMXComposer2ForCausalLM",
-    "AutoModelForCausalLM": "geopix.GeoPixForCausalLM"
+    "AutoModelForCausalLM": "geopixel.GeoPixelForCausalLM"
   },
   "bias": false,
   "bos_token_id": 1,
geopixel.py ADDED
from typing import List, Optional, Tuple, Union

import os
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from model.IXC.modeling_internlm_xcomposer2 import InternLMXComposer2ForCausalLM
from model.IXC.modeling_internlm2 import InternLM2Model
from model.sam2.build_sam import build_sam2_hf
from model.sam2.utils.transforms import SAM2Transforms

try:
    from transformers.generation.streamers import BaseStreamer
except:  # noqa # pylint: disable=bare-except
    BaseStreamer = None


def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,  # 100000.0,
    eps=1e-6,
):
    """
    Compute the DICE loss, similar to generalized IoU for masks.
    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1, 2)
    targets = targets.flatten(1, 2)
    numerator = 2 * (inputs / scale * targets).sum(-1)
    denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
    loss = 1 - (numerator + eps) / (denominator + eps)
    loss = loss.sum() / (num_masks + 1e-8)
    return loss


def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
    Returns:
        Loss tensor
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
    return loss


class GeoPixelMetaModel:
    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(GeoPixelMetaModel, self).__init__(config)
        self.config = config
        self.config.train_mask_decoder = getattr(self.config, "train_mask_decoder", kwargs.get("train_mask_decoder", False))
        self.config.out_dim = getattr(self.config, "out_dim", kwargs.get("out_dim", 256))
        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.initialize_geopixel_modules(self.config)

    def initialize_geopixel_modules(self, config):
        # grounding vision model (SAM 2), frozen except for the optional mask decoder
        self.visual_model = build_sam2_hf(self.vision_pretrained)

        self._transform = SAM2Transforms(
            resolution=self.visual_model.image_size,
            mask_threshold=0.0,
            max_hole_area=0.0,
            max_sprinkle_area=0.0,
        )
        # Spatial dims for backbone feature maps
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if config.train_mask_decoder:
            self.visual_model.sam_mask_decoder.train()
            for param in self.visual_model.sam_mask_decoder.parameters():
                param.requires_grad = True

        # text projection layer: LLM hidden states -> SAM 2 prompt embedding space
        in_dim = config.hidden_size
        out_dim = config.out_dim
        text_projection_layers = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_projection_layers)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True


class GeoPixelModel(GeoPixelMetaModel, InternLM2Model):
    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(GeoPixelModel, self).__init__(config, **kwargs)
        self.config.use_cache = False


class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
    def __init__(self, config, **kwargs):
        self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
        self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        self.seg_token_idx = kwargs.pop("seg_token_idx")

        super().__init__(config)
        self.model = GeoPixelModel(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def encode_g_img(self, image):
        """
        Calculates the image embeddings for the provided image.
        Arguments:
            image (np.ndarray or str)
        """
        if image is None:
            return None
        if isinstance(image, str):
            _, ext = os.path.splitext(image)
            if ext.lower() in {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.tif'}:
                image = Image.open(image)
                w, h = image.size
                _orig_hw = [(h, w)]
            else:
                print('Unknown input format', image)
                return None
        else:
            assert isinstance(image, torch.Tensor)
            _orig_hw = [image.shape[:2]]
        image = self.model._transform(image)
        image = image[None, ...].to(self.device)
        assert (
            len(image.shape) == 4 and image.shape[1] == 3
        ), f"image must be of size 1x3xHxW, got {image.shape}"
        features = self.get_visual_embs(image)
        return features, _orig_hw

    def get_visual_embs(self, img_batch: torch.FloatTensor):
        with torch.no_grad():
            torch.cuda.empty_cache()
            img_batch = img_batch.to(self.device)
            batch_size = img_batch.shape[0]
            assert (
                len(img_batch.shape) == 4 and img_batch.shape[1] == 3
            ), f"grounding_img_batch must be of size Bx3xHxW, got {img_batch.shape}"
            backbone_out = self.model.visual_model.forward_image(img_batch)
            _, vision_feats, _, _ = self.model.visual_model._prepare_backbone_features(backbone_out)
            if self.model.visual_model.directly_add_no_mem_embed:
                vision_feats[-1] = vision_feats[-1] + self.model.visual_model.no_mem_embed
            feats = [
                feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
                for feat, feat_size in zip(vision_feats[::-1], self.model._bb_feat_sizes[::-1])
            ][::-1]
            features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        return features

    def forward(self, **kwargs):
        return super().forward(**kwargs) if "past_key_values" in kwargs else self.model_forward(**kwargs)

    def model_forward(
        self,
        inference: bool = False,
        **kwargs,
    ):
        samples = kwargs.get('samples', None)
        if samples and samples['data_type'][0] == 'grounding':
            kwargs['output_hidden_states'] = True
            torch.cuda.empty_cache()
            outputs = super().forward(**kwargs)
            # keep the [SEG] token mask before `outputs` is released in the inference branch
            seg_token_mask = outputs.seg_token_mask

            if inference:
                # single image and single query
                assert len(samples['text_input']) == 1 and len(samples['image'][0]) == 1
                output_hidden_states = [outputs.hidden_states]
                outputs = None
            else:
                output_hidden_states = outputs.hidden_states

            hidden_states = []
            assert len(self.model.text_hidden_fcs) == 1
            hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states[-1]))
            last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)

            pred_embeddings = [states[masks] for states, masks in zip(last_hidden_state, seg_token_mask)]
            image_g_batch = torch.cat(samples['image_g'][0], dim=0)
            image_g_features = self.get_visual_embs(image_g_batch)
            ori_hw = samples['ori_hw'][0]
            all_pred_masks = []
            for i in range(len(pred_embeddings)):  # (bs,)
                if pred_embeddings[i].numel() == 0:
                    all_pred_masks.append([])
                    continue
                sparse_embeddings, dense_embeddings = self.model.visual_model.sam_prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=pred_embeddings[i].unsqueeze(1),
                )
                batch_mode = pred_embeddings[i].shape[0] > 1
                high_res_features = [
                    feat_level[i].unsqueeze(0)
                    for feat_level in image_g_features["high_res_feats"]
                ]
                sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                image_g_embeds = image_g_features['image_embed'][i].unsqueeze(0).to(torch.bfloat16)
                low_res_masks, _, _, _ = self.model.visual_model.sam_mask_decoder(
                    image_embeddings=image_g_embeds,
                    image_pe=self.model.visual_model.sam_prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    repeat_image=batch_mode,
                    multimask_output=False,
                    high_res_features=high_res_features,
                )
                pred_masks = self.model._transform.postprocess_masks(
                    low_res_masks,
                    ori_hw[i],
                )
                all_pred_masks.append(pred_masks[:, 0])

            model_output = outputs
            gt_masks = samples['masks'][0]
            pred_masks = all_pred_masks

            if inference:
                return {
                    "pred_masks": pred_masks,
                    "gt_masks": gt_masks,
                }

            ce_loss = model_output.loss
            ce_loss = ce_loss * self.ce_loss_weight
            mask_bce_loss = 0
            mask_dice_loss = 0
            num_masks = 0

            for batch_idx in range(len(pred_masks)):  # for every image
                cur_gt_masks = torch.stack(
                    [
                        torch.from_numpy(gt_mask).to(dtype=pred_masks[batch_idx].dtype, device=pred_masks[batch_idx].device)
                        for gt_mask in gt_masks[batch_idx]
                    ],
                    dim=0,
                )  # expected (bs, H, W)
                cur_pred_masks = pred_masks[batch_idx]
                assert (
                    cur_gt_masks.shape[0] == cur_pred_masks.shape[0]
                ), "gt_masks.shape: {}, pred_masks.shape: {}".format(
                    cur_gt_masks.shape, cur_pred_masks.shape
                )
                mask_bce_loss += (
                    sigmoid_ce_loss(cur_pred_masks, cur_gt_masks, num_masks=cur_gt_masks.shape[0])
                    * cur_gt_masks.shape[0]
                )
                mask_dice_loss += (
                    dice_loss(cur_pred_masks, cur_gt_masks, num_masks=cur_gt_masks.shape[0])
                    * cur_gt_masks.shape[0]
                )
                num_masks += cur_gt_masks.shape[0]

            mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
            mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
            mask_loss = mask_bce_loss + mask_dice_loss

            loss = ce_loss + mask_loss
            outputs = CausalLMOutputWithPast(
                loss=loss,
                logits=model_output.logits,
                past_key_values=model_output.past_key_values,
                hidden_states=output_hidden_states,
                attentions=model_output.attentions,
            )
            outputs.ce_loss = ce_loss
            outputs.mask_bce_loss = mask_bce_loss
            outputs.mask_dice_loss = mask_dice_loss
            outputs.mask_loss = mask_loss
        else:
            outputs = super().forward(**kwargs)
        return outputs

    def evaluate(
        self,
        tokenizer,
        query: str,
        images: List[Tuple[str, str]] = [],
        hd_num: int = 9,
        history: List[Tuple[str, str]] = [],
        max_new_tokens: int = 1024,
        **kwargs,
    ):
        with torch.no_grad():
            inputs, im_mask, _ = self.interleav_wrap_chat(query, images, history=history, hd_num=hd_num)
            inputs = {
                k: v.to(self.device)
                for k, v in inputs.items() if torch.is_tensor(v)
            }
            eos_token_id = [
                tokenizer.eos_token_id,
                # tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
            ]
            all_pred_masks = []
            outputs = self.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                im_mask=im_mask,
                input_ids=None,
                streamer=None,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                top_p=1.0,
                top_k=0,
                eos_token_id=eos_token_id,
                repetition_penalty=1.0,
                infer_mode='base',
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            output_ids = outputs['sequences']
            response = tokenizer.decode(output_ids[0].cpu().tolist(), skip_special_tokens=True)
            response = response.replace("[UNUSED_TOKEN_145]", "")
            history = history + [(query, response)]
            if len(images) == 1 and isinstance(images[0], str):
                output_hidden_states = outputs.hidden_states[-1]
                seg_token_mask = output_ids[:, 1:-1] == self.seg_token_idx
                inputs_embeds_len = inputs['inputs_embeds'].size(1)
                # prepend zeros for the prompt positions so the mask aligns with the full sequence
                seg_token_mask = torch.cat(
                    [
                        torch.zeros((seg_token_mask.shape[0], inputs_embeds_len)).bool().cuda(),
                        seg_token_mask,
                    ],
                    dim=1,
                )
                hidden_states = []
                assert len(self.model.text_hidden_fcs) == 1
                hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states))
                last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
                pred_embeddings = [states[masks] for states, masks in zip(last_hidden_state, seg_token_mask)]
                image_g_features, ori_hw = self.encode_g_img(images[0])

                for i in range(len(pred_embeddings)):
                    if pred_embeddings[i].numel() == 0:
                        all_pred_masks.append([])
                        continue
                    sparse_embeddings, dense_embeddings = self.model.visual_model.sam_prompt_encoder(
                        points=None,
                        boxes=None,
                        masks=None,
                        text_embeds=pred_embeddings[i].unsqueeze(1),
                    )
                    batch_mode = pred_embeddings[i].shape[0] > 1
                    high_res_features = [
                        feat_level[i].unsqueeze(0)
                        for feat_level in image_g_features["high_res_feats"]
                    ]
                    sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                    image_g_embeds = image_g_features['image_embed'][i].unsqueeze(0).to(torch.bfloat16)

                    low_res_masks, _, _, _ = self.model.visual_model.sam_mask_decoder(
                        image_embeddings=image_g_embeds,
                        image_pe=self.model.visual_model.sam_prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=sparse_embeddings,
                        dense_prompt_embeddings=dense_embeddings,
                        repeat_image=batch_mode,
                        multimask_output=False,
                        high_res_features=high_res_features,
                    )
                    pred_masks = self.model._transform.postprocess_masks(
                        low_res_masks,
                        ori_hw[i],
                    )
                    all_pred_masks.append(pred_masks[:, 0])

        return response, all_pred_masks
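
For reference, a hedged inference sketch using the evaluate helper defined above. It assumes model and tokenizer were loaded as in the loading sketch after config.json; the query string and image path are placeholders.

    query = "Please describe the scene and segment the storage tanks."  # hypothetical prompt
    image_path = "example.tif"  # hypothetical local file; a string path is routed through encode_g_img

    response, pred_masks = model.evaluate(
        tokenizer,
        query,
        images=[image_path],
        max_new_tokens=512,
    )

    print(response)  # text answer; [SEG] positions in the output drive the mask decoder
    for masks in pred_masks:
        if len(masks):  # (num_seg_tokens, H, W) tensor at the original image resolution
            print(masks.shape)
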
pytorch_model.bin.index.json CHANGED
@@ -2218,7 +2218,7 @@
     "model.visual_model.memory_attention.norm.weight": "pytorch_model-00003-of-00003.bin",
     "model.visual_model.memory_encoder.fuser.layers.0.dwconv.bias": "pytorch_model-00003-of-00003.bin",
     "model.visual_model.memory_encoder.fuser.layers.0.dwconv.weight": "pytorch_model-00003-of-00003.bin",
-    "model.visual_model.memory_encoder.fuser.layers.0.gamma": "pytorch_model-00003-of-00003.bin",
+    "model.visual_model.memory_encoder.fuser.layers.0.weight": "pytorch_model-00003-of-00003.bin",
     "model.visual_model.memory_encoder.fuser.layers.0.norm.bias": "pytorch_model-00003-of-00003.bin",
     "model.visual_model.memory_encoder.fuser.layers.0.norm.weight": "pytorch_model-00003-of-00003.bin",
     "model.visual_model.memory_encoder.fuser.layers.0.pwconv1.bias": "pytorch_model-00003-of-00003.bin",
@@ -2227,7 +2227,7 @@
     "model.visual_model.memory_encoder.fuser.layers.0.pwconv2.weight": "pytorch_model-00003-of-00003.bin",
     "model.visual_model.memory_encoder.fuser.layers.1.dwconv.bias": "pytorch_model-00003-of-00003.bin",
     "model.visual_model.memory_encoder.fuser.layers.1.dwconv.weight": "pytorch_model-00003-of-00003.bin",
-    "model.visual_model.memory_encoder.fuser.layers.1.gamma": "pytorch_model-00003-of-00003.bin",
+    "model.visual_model.memory_encoder.fuser.layers.1.weight": "pytorch_model-00003-of-00003.bin",
     "model.visual_model.memory_encoder.fuser.layers.1.norm.bias": "pytorch_model-00003-of-00003.bin",
     "model.visual_model.memory_encoder.fuser.layers.1.norm.weight": "pytorch_model-00003-of-00003.bin",
     "model.visual_model.memory_encoder.fuser.layers.1.pwconv1.bias": "pytorch_model-00003-of-00003.bin",