Upload mmalaya_arch.py

mmalaya_arch.py  CHANGED  +4 -44
@@ -3,7 +3,7 @@ import re
import torch
import torch.nn as nn
from transformers import Blip2Model, Blip2Processor, Blip2Config
-from .mm_utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from .mm_utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN


class BLIP2VisionTower(nn.Module):

@@ -265,46 +265,6 @@ class MMAlayaMetaForCausalLM(ABC):

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

-    def initialize_vision_tokenizer(self, model_args, tokenizer):
-
-
-        self.resize_token_embeddings(len(tokenizer))
-
-        if model_args.mm_use_im_start_end:
-            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
-            self.resize_token_embeddings(len(tokenizer))
-
-            if num_new_tokens > 0:
-                input_embeddings = self.get_input_embeddings().weight.data
-                output_embeddings = self.get_output_embeddings().weight.data
-
-                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
-                    dim=0, keepdim=True)
-                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-                    dim=0, keepdim=True)
-
-                input_embeddings[-num_new_tokens:] = input_embeddings_avg
-                output_embeddings[-num_new_tokens:] = output_embeddings_avg
-
-            if model_args.tune_mm_mlp_adapter:
-                for p in self.get_input_embeddings().parameters():
-                    p.requires_grad = True
-                for p in self.get_output_embeddings().parameters():
-                    p.requires_grad = False
-
-            if model_args.pretrain_mm_mlp_adapter:
-                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
-                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
-                assert num_new_tokens == 2
-                if input_embeddings.shape == embed_tokens_weight.shape:
-                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
-                elif embed_tokens_weight.shape[0] == num_new_tokens:
-                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
-                else:
-                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
-        elif model_args.mm_use_im_patch_token:
-            if model_args.tune_mm_mlp_adapter:
-                for p in self.get_input_embeddings().parameters():
-                    p.requires_grad = False
-                for p in self.get_output_embeddings().parameters():
-                    p.requires_grad = False
+    def initialize_vision_tokenizer(self, tokenizer):
+        tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN], special_tokens=True)
+        self.resize_token_embeddings(len(tokenizer))
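
After this change the vision-token setup no longer depends on model_args: it only registers the single DEFAULT_IMAGE_TOKEN placeholder and resizes the embeddings. A minimal usage sketch, assuming the method is called on a loaded MMAlaya language-model class that exposes this mixin method (the checkpoint path below is illustrative, not taken from this repo):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative names; the repo's own loading code may differ.
tokenizer = AutoTokenizer.from_pretrained("path/to/mmalaya")
model = AutoModelForCausalLM.from_pretrained("path/to/mmalaya")

# New signature: only the tokenizer is needed. The method registers
# DEFAULT_IMAGE_TOKEN as a special token and resizes the embedding
# matrix so the newly added id has a row.
model.initialize_vision_tokenizer(tokenizer)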
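For contrast, the removed block followed the older LLaVA-style scheme: optionally add start/end image markers, then initialize the new embedding rows to the mean of the pre-existing rows rather than leaving them randomly initialized. A standalone sketch of that technique, with assumed marker strings standing in for the DEFAULT_IM_START_TOKEN / DEFAULT_IM_END_TOKEN constants (an illustration of the removed approach, not an API of this repo):

from transformers import AutoModelForCausalLM, AutoTokenizer

def add_tokens_with_mean_init(model, tokenizer, new_tokens):
    # Register the tokens and grow both embedding matrices to the new vocab size.
    num_new = tokenizer.add_tokens(new_tokens, special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))
    if num_new > 0:
        input_emb = model.get_input_embeddings().weight.data
        output_emb = model.get_output_embeddings().weight.data
        # Set the appended rows to the average of the rows that existed before,
        # as the removed code did instead of keeping random initialization.
        input_emb[-num_new:] = input_emb[:-num_new].mean(dim=0, keepdim=True)
        output_emb[-num_new:] = output_emb[:-num_new].mean(dim=0, keepdim=True)
    return num_new

# Illustrative usage (paths and marker strings are assumptions):
# tokenizer = AutoTokenizer.from_pretrained("path/to/base-llm")
# model = AutoModelForCausalLM.from_pretrained("path/to/base-llm")
# add_tokens_with_mean_init(model, tokenizer, ["<im_start>", "<im_end>"])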