Update seagull/model/seagull_arch.py
seagull/model/seagull_arch.py (CHANGED, +11 -11)
@@ -92,7 +92,7 @@ class SeagullMetaForCausalLM(ABC):
         image_features, image_features_dict = self.encode_images(images)
 
 
-
+        global_features_, local_features_ = self.mask_extractor(image_features_dict, masks, cropped_img=cropped_img)
 
         new_input_embeds = []
         new_labels = [] if labels is not None else None
@@ -151,10 +151,10 @@ class SeagullMetaForCausalLM(ABC):
                 _l = 0
                 for i, idx in enumerate(mask_idx):
                     cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:idx[0]]).detach())
-                    ##
-                    cur_new_input_embeds.append(
-                    ##
-                    cur_new_input_embeds.append(
+                    ## global
+                    cur_new_input_embeds.append(global_features_[batch_idx][i:i+1].detach())
+                    ## local
+                    cur_new_input_embeds.append(local_features_[batch_idx][i:i+1].detach())
                     if labels is not None:
                         cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
                     _l = idx[0]+2
@@ -164,16 +164,16 @@ class SeagullMetaForCausalLM(ABC):
             else:
 
                 mask_idx = torch.nonzero(cur_input_ids==self.tokenizer.convert_tokens_to_ids(['<global>'])[0])
-                assert len(mask_idx) == len(
+                assert len(mask_idx) == len(global_features_[batch_idx]), "mask num not equal to mask feats"
 
                 _l = 0
                 for i, idx in enumerate(mask_idx):
                     cur_raw_new_input_embeds = self.get_model().embed_tokens(cur_input_ids[_l:idx[0]])
                     cur_new_input_embeds.append(cur_raw_new_input_embeds)
-                    ##
-                    cur_new_input_embeds.append(
-                    ##
-                    cur_new_input_embeds.append(
+                    ## global
+                    cur_new_input_embeds.append(global_features_[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
+                    ## local
+                    cur_new_input_embeds.append(local_features_[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
 
                     if labels is not None:
                         cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
@@ -235,7 +235,7 @@ class SeagullMetaForCausalLM(ABC):
         tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
         self.resize_token_embeddings(len(tokenizer))
 
-        mask_tokens = ['<global>', '<
+        mask_tokens = ['<global>', '<local>']
         num_new_tokens = tokenizer.add_tokens(mask_tokens, special_tokens=True)
 
         if model_args.mm_use_im_start_end:
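In effect, the change passes the encoded image features and region masks through the new mask_extractor, which returns one global and one local feature vector per region, then splices those vectors into the token-embedding sequence wherever the prompt carries a <global>/<local> token pair; the two spliced positions are excluded from the loss via IGNORE_INDEX. The `idx[0]+2` stride implies a <local> token immediately follows each <global> token. A minimal self-contained sketch of that splice step, with hypothetical names and shapes (input_embeds as (seq_len, dim), global_feats/local_feats as (num_regions, dim)) rather than the repository's exact code:

import torch

IGNORE_INDEX = -100  # positions labeled with this value are dropped from the LM loss

def splice_mask_features(input_ids, input_embeds, labels,
                         global_feats, local_feats, global_token_id):
    # Hypothetical sketch of the diff's splice loop for one sequence.
    mask_idx = torch.nonzero(input_ids == global_token_id)   # (num_regions, 1)
    assert len(mask_idx) == len(global_feats), "mask num not equal to mask feats"

    pieces, _l = [], 0
    for i, idx in enumerate(mask_idx):
        pieces.append(input_embeds[_l:idx[0]])       # text before the <global><local> pair
        pieces.append(global_feats[i:i+1])           # fills the <global> slot
        pieces.append(local_feats[i:i+1])            # fills the <local> slot
        if labels is not None:
            labels[idx[0]:idx[0]+2] = IGNORE_INDEX   # exclude both slots from the loss
        _l = idx[0] + 2                              # resume after the pair
    pieces.append(input_embeds[_l:])                 # trailing text
    return torch.cat(pieces, dim=0), labels

Note the asymmetry between the two branches in the diff: the first detaches both the text embeddings and the region features, so no gradient reaches them on that path, while the second keeps gradients and only casts the features to the embedding dtype.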
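The final hunk registers <global> and <local> as additional special tokens, matching the placeholders the splice loop searches for, so the tokenizer emits dedicated ids for them. This follows the standard Hugging Face transformers pattern; a generic sketch, where the checkpoint name is a placeholder for illustration and not SEAGULL's actual base model:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder base model, assumed for this sketch only.
name = "lmsys/vicuna-7b-v1.5"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

# Register the region placeholders as special tokens, as the diff does.
mask_tokens = ['<global>', '<local>']
num_new_tokens = tokenizer.add_tokens(mask_tokens, special_tokens=True)

# Resize after adding tokens so the new ids index valid embedding rows.
model.resize_token_embeddings(len(tokenizer))

# These are the ids the splice code looks up via convert_tokens_to_ids.
print(tokenizer.convert_tokens_to_ids(mask_tokens))

Passing special_tokens=True keeps the markers atomic during tokenization and lets skip_special_tokens=True drop them at decode time.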