PyTorch
English
internlm2
code
custom_code
AkashahS committed (verified) · Commit 1af05d0 · 1 Parent(s): 21dddf0

Upload folder using huggingface_hub

added_tokens.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "</p>": 92552,
+   "<p>": 92551,
+   "<|action_end|>": 92547,
+   "<|action_start|>": 92546,
+   "<|im_end|>": 92545,
+   "<|im_start|>": 92544,
+   "<|interpreter|>": 92548,
+   "<|plugin|>": 92549,
+   "[SEG]": 92550
+ }
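This file registers the tokens added on top of the base InternLM2 vocabulary, including the grounding tags <p>/</p> and the segmentation token [SEG] at id 92550. A minimal sketch of resolving these ids at runtime (assuming the checkpoint is loaded by the `MBZUAI/GeoPixel-7B-RES` repo id listed in config.json and that the bundled custom code is trusted):

from transformers import AutoTokenizer

# The custom tokenizer class ships with the checkpoint, so trust_remote_code is required.
tokenizer = AutoTokenizer.from_pretrained("MBZUAI/GeoPixel-7B-RES", trust_remote_code=True)

# [SEG] is the token whose hidden state is later projected into a SAM-2 prompt (see geopixel.py below).
seg_token_idx = tokenizer.convert_tokens_to_ids("[SEG]")
print(seg_token_idx)  # expected: 92550, matching added_tokens.json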
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "_name_or_path": "MBZUAI/GeoPixel-7B-RES",
+   "architectures": [
+     "GeoPixelForCausalLM"
+   ],
+   "attn_implementation": "flash_attention_2",
+   "auto_map": {
+     "AutoConfig": "configuration_internlm_xcomposer2.InternLMXcomposer2Config",
+     "AutoModel": "modeling_internlm_xcomposer2.InternLMXComposer2ForCausalLM",
+     "AutoModelForCausalLM": "geopixel.GeoPixelForCausalLM"
+   },
+   "bias": false,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 14336,
+   "max_length": 16384,
+   "max_position_embeddings": 24576,
+   "model_type": "internlm2",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 32,
+   "num_key_value_heads": 8,
+   "out_dim": 256,
+   "pad_token_id": 0,
+   "rms_norm_eps": 1e-05,
+   "rope_scaling": {
+     "factor": 2.0,
+     "type": "dynamic"
+   },
+   "rope_theta": 1000000,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "train_mask_decoder": true,
+   "transformers_version": "4.33.2",
+   "use_cache": false,
+   "vocab_size": 92553
+ }
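Because auto_map points at the bundled configuration_internlm_xcomposer2.py, this config resolves through AutoConfig. A minimal sketch of loading it and inspecting the fields above (repo id taken from _name_or_path; trust_remote_code is required for the custom config class):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("MBZUAI/GeoPixel-7B-RES", trust_remote_code=True)

# A few of the values defined in the diff above.
print(config.model_type)               # "internlm2"
print(config.num_key_value_heads)      # 8 -> grouped-query attention (32 query heads)
print(config.rope_scaling)             # {"factor": 2.0, "type": "dynamic"}
print(config.max_position_embeddings)  # 24576, with max_length 16384 for generation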
configuration_internlm_xcomposer2.py ADDED
@@ -0,0 +1,150 @@
+ # coding=utf-8
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ InternLM2 model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+ class InternLMXcomposer2Config(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
+     an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 32000):
+             Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`InternLM2Model`]
+         hidden_size (`int`, *optional*, defaults to 4096):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 11008):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 32):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         num_key_value_heads (`int`, *optional*):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+             by meanpooling all the original heads within that group. For more details checkout [this
+             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+             `num_attention_heads`.
+         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+             The non-linear activation function (function or string) in the decoder.
+         max_position_embeddings (`int`, *optional*, defaults to 2048):
+             The maximum sequence length that this model might ever be used with. Typically set this to something large
+             just in case (e.g., 512 or 1024 or 2048).
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-6):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `False`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
+             Whether to tie weight embeddings
+     Example:
+
+     """
+     model_type = "internlm2"
+     _auto_class = "AutoConfig"
+
+     def __init__(  # pylint: disable=W0102
+         self,
+         vocab_size=103168,
+         hidden_size=4096,
+         intermediate_size=11008,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=None,
+         hidden_act="silu",
+         max_position_embeddings=2048,
+         initializer_range=0.02,
+         rms_norm_eps=1e-6,
+         use_cache=False,
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         tie_word_embeddings=False,
+         bias=True,
+         rope_theta=10000,
+         rope_scaling=None,
+         attn_implementation="flash_attention_2",
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.bias = bias
+
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self._rope_scaling_validation()
+
+         self.attn_implementation = attn_implementation
+         if self.attn_implementation is None:
+             self.attn_implementation = "flash_attention_2"
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+     def _rope_scaling_validation(self):
+         """
+         Validate the `rope_scaling` configuration.
+         """
+         if self.rope_scaling is None:
+             return
+
+         if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+             raise ValueError(
+                 "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
+                 f"got {self.rope_scaling}"
+             )
+         rope_scaling_type = self.rope_scaling.get("type", None)
+         rope_scaling_factor = self.rope_scaling.get("factor", None)
+         if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+             raise ValueError(
+                 f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+             )
+         if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
+             raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
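As a quick illustration of the constructor and the rope_scaling validation above, the class can also be instantiated directly. A minimal sketch that only uses the file shown in this diff (assuming it is importable from the working directory; the keyword values mirror config.json):

from configuration_internlm_xcomposer2 import InternLMXcomposer2Config

# Dynamic NTK-style RoPE scaling with factor 2.0, as shipped in config.json.
cfg = InternLMXcomposer2Config(
    vocab_size=92553,
    num_key_value_heads=8,
    rope_scaling={"type": "dynamic", "factor": 2.0},
)
print(cfg.rope_scaling)

# _rope_scaling_validation() rejects anything other than "linear" or "dynamic".
try:
    InternLMXcomposer2Config(rope_scaling={"type": "yarn", "factor": 2.0})
except ValueError as err:
    print(err)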
generation_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "max_length": 16384,
+   "pad_token_id": 2,
+   "transformers_version": "4.33.2",
+   "use_cache": false
+ }
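These defaults are picked up automatically by generate(), but they can also be inspected on their own (a minimal sketch; repo id assumed as above):

from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("MBZUAI/GeoPixel-7B-RES")
print(gen_cfg.max_length)    # 16384
print(gen_cfg.pad_token_id)  # 2, i.e. padding reuses the eos id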
geopixel.py ADDED
@@ -0,0 +1,418 @@
+ from typing import List, Optional, Tuple, Union
+
+ import os
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ import torch.nn.functional as F
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from model.IXC.modeling_internlm_xcomposer2 import InternLMXComposer2ForCausalLM
+ from model.IXC.modeling_internlm2 import InternLM2Model
+ from model.sam2.build_sam import build_sam2_hf
+ from model.sam2.utils.transforms import SAM2Transforms
+ from transformers import TextStreamer
+ try:
+     from transformers.generation.streamers import BaseStreamer
+ except:  # noqa # pylint: disable=bare-except
+     BaseStreamer = None
+
+
+ def dice_loss(
+     inputs: torch.Tensor,
+     targets: torch.Tensor,
+     num_masks: float,
+     scale=1000,  # 100000.0,
+     eps=1e-6,
+ ):
+     """
+     Compute the DICE loss, similar to generalized IOU for masks
+     Args:
+         inputs: A float tensor of arbitrary shape.
+                 The predictions for each example.
+         targets: A float tensor with the same shape as inputs. Stores the binary
+                  classification label for each element in inputs
+                  (0 for the negative class and 1 for the positive class).
+     """
+     inputs = inputs.sigmoid()
+     inputs = inputs.flatten(1, 2)
+     targets = targets.flatten(1, 2)
+     numerator = 2 * (inputs / scale * targets).sum(-1)
+     denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
+     loss = 1 - (numerator + eps) / (denominator + eps)
+     loss = loss.sum() / (num_masks + 1e-8)
+     return loss
+
+
+ def sigmoid_ce_loss(
+     inputs: torch.Tensor,
+     targets: torch.Tensor,
+     num_masks: float,
+ ):
+     """
+     Args:
+         inputs: A float tensor of arbitrary shape.
+                 The predictions for each example.
+         targets: A float tensor with the same shape as inputs. Stores the binary
+                  classification label for each element in inputs
+                  (0 for the negative class and 1 for the positive class).
+     Returns:
+         Loss tensor
+     """
+     loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+     loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
+     return loss
+
+
+ class GeoPixelMetaModel:
+     def __init__(
+         self,
+         config,
+         **kwargs,
+     ):
+         super(GeoPixelMetaModel, self).__init__(config)
+         self.config = config
+         self.config.train_mask_decoder = getattr(self.config, "train_mask_decoder", kwargs.get("train_mask_decoder", False))
+         self.config.out_dim = getattr(self.config, "out_dim", kwargs.get("out_dim", 256))
+         self.vision_pretrained = kwargs.get("vision_pretrained", None)
+         self.initialize_geopixel_modules(self.config)
+
+     def initialize_geopixel_modules(self, config):
+         # grounding vision model
+         self.visual_model = build_sam2_hf(self.vision_pretrained)
+
+         self._transform = SAM2Transforms(
+             resolution=self.visual_model.image_size,
+             mask_threshold=0.0,
+             max_hole_area=0.0,
+             max_sprinkle_area=0.0,
+         )
+         # Spatial dim for backbone feature maps
+         self._bb_feat_sizes = [
+             (256, 256),
+             (128, 128),
+             (64, 64),
+         ]
+
+         for param in self.visual_model.parameters():
+             param.requires_grad = False
+
+         if config.train_mask_decoder:
+             self.visual_model.sam_mask_decoder.train()
+             for param in self.visual_model.sam_mask_decoder.parameters():
+                 param.requires_grad = True
+
+         # text projection layer
+         in_dim = config.hidden_size
+         out_dim = config.out_dim
+         text_projection_layers = [
+             nn.Linear(in_dim, in_dim),
+             nn.ReLU(inplace=True),
+             nn.Linear(in_dim, out_dim),
+             nn.Dropout(0.0),
+         ]
+         self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_projection_layers)])
+         self.text_hidden_fcs.train()
+         for param in self.text_hidden_fcs.parameters():
+             param.requires_grad = True
+
+
+ class GeoPixelModel(GeoPixelMetaModel, InternLM2Model):
+     def __init__(
+         self,
+         config,
+         **kwargs,
+     ):
+         super(GeoPixelModel, self).__init__(config, **kwargs)
+         self.config.use_cache = False
+
+
+ class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
+     def __init__(self, config, **kwargs):
+
+         self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
+         self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
+         self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
+         self.seg_token_idx = kwargs.pop("seg_token_idx")
+
+         super().__init__(config)
+         self.model = GeoPixelModel(config, **kwargs)
+         self.vocab_size = config.vocab_size
+         self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.post_init()
+
+     def encode_g_img(self, image):
+         """
+         Calculates the image embeddings for the provided image
+         Arguments:
+             image (np.ndarray or str)
+         """
+         if image is None:
+             return None
+         if isinstance(image, str):
+             _, ext = os.path.splitext(image)
+             if ext.lower() in {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.tif'}:
+                 image = Image.open(image)
+                 w, h = image.size
+                 _orig_hw = [(h, w)]
+             else:
+                 print('Unknown input format', image)
+                 return None
+         else:
+             assert isinstance(image, torch.Tensor)
+             _orig_hw = [image.shape[:2]]
+         image = self.model._transform(image)
+         image = image[None, ...].to(self.device)
+         assert (len(image.shape) == 4 and image.shape[1] == 3), f"image must be of size 1x3xHxW, got {image.shape}"
+         features = self.get_visual_embs(image)
+         return features, _orig_hw
+
+     def get_visual_embs(self, img_batch: torch.FloatTensor):
+         with torch.no_grad():
+             torch.cuda.empty_cache()
+             img_batch = img_batch.to(self.device)
+             batch_size = img_batch.shape[0]
+             assert (
+                 len(img_batch.shape) == 4 and img_batch.shape[1] == 3
+             ), f"grounding_img_batch must be of size Bx3xHxW, got {img_batch.shape}"
+             backbone_out = self.model.visual_model.forward_image(img_batch)
+             _, vision_feats, _, _ = self.model.visual_model._prepare_backbone_features(backbone_out)
+             if self.model.visual_model.directly_add_no_mem_embed:
+                 vision_feats[-1] = vision_feats[-1] + self.model.visual_model.no_mem_embed
+             feats = [
+                 feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+                 for feat, feat_size in zip(vision_feats[::-1], self.model._bb_feat_sizes[::-1])
+             ][::-1]
+             features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+         return features
+
+     def forward(self, **kwargs):
+         return super().forward(**kwargs) if "past_key_values" in kwargs else self.model_forward(**kwargs)
+
+     def model_forward(
+         self,
+         inference: bool = False,
+         **kwargs,
+     ):
+         samples = kwargs.get('samples', None)
+         if samples and samples['data_type'][0] == 'grounding':
+             kwargs['output_hidden_states'] = True
+             kwargs['use_cache'] = False
+
+             torch.cuda.empty_cache()
+             outputs = super().forward(**kwargs)
+
+             if inference:
+                 assert len(samples['text_input']) == 1 and len(samples['image'][0]) == 1  # single image and single query
+                 output_hidden_states = [outputs.hidden_states]
+                 outputs = None
+             else:
+                 output_hidden_states = outputs.hidden_states
+
+             hidden_states = []
+             assert len(self.model.text_hidden_fcs) == 1
+             hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states[-1]))
+             last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
+
+             seg_token_mask = outputs.seg_token_mask
+             pred_embeddings = [states[masks] for states, masks in zip(last_hidden_state, seg_token_mask)]
+             image_g_batch = torch.cat(samples['image_g'][0], dim=0)
+             image_g_features = self.get_visual_embs(image_g_batch)
+             ori_hw = samples['ori_hw'][0]
+             all_pred_masks = []
+             for i in range(len(pred_embeddings)):  # (bs,)
+                 if pred_embeddings[i].numel() == 0:
+                     all_pred_masks.append([])
+                     continue
+                 (sparse_embeddings, dense_embeddings,) = self.model.visual_model.sam_prompt_encoder(
+                     points=None,
+                     boxes=None,
+                     masks=None,
+                     text_embeds=pred_embeddings[i].unsqueeze(1),
+                 )
+                 batch_mode = (pred_embeddings[i].shape[0] > 1)
+                 high_res_features = [
+                     feat_level[i].unsqueeze(0)
+                     for feat_level in image_g_features["high_res_feats"]
+                 ]
+                 sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
+                 image_g_embeds = image_g_features['image_embed'][i].unsqueeze(0).to(torch.bfloat16)
+                 low_res_masks, _, _, _ = self.model.visual_model.sam_mask_decoder(
+                     image_embeddings=image_g_embeds,
+                     image_pe=self.model.visual_model.sam_prompt_encoder.get_dense_pe(),
+                     sparse_prompt_embeddings=sparse_embeddings,
+                     dense_prompt_embeddings=dense_embeddings,
+                     repeat_image=batch_mode,
+                     multimask_output=False,
+                     high_res_features=high_res_features,
+                 )
+                 pred_masks = self.model._transform.postprocess_masks(
+                     low_res_masks,
+                     ori_hw[i],
+                 )
+                 all_pred_masks.append(pred_masks[:, 0])
+
+             model_output = outputs
+             gt_masks = samples['masks'][0]
+             pred_masks = all_pred_masks
+
+             if inference:
+                 return {
+                     "pred_masks": pred_masks,
+                     "gt_masks": gt_masks,
+                 }
+
+             ce_loss = model_output.loss
+             ce_loss = ce_loss * self.ce_loss_weight
+             mask_bce_loss = 0
+             mask_dice_loss = 0
+             num_masks = 0
+
+             for batch_idx in range(len(pred_masks)):  # for every image
+                 cur_gt_masks = torch.stack(
+                     [
+                         torch.from_numpy(gt_mask).to(dtype=pred_masks[batch_idx].dtype, device=pred_masks[batch_idx].device)
+                         for gt_mask in gt_masks[batch_idx]
+                     ],
+                     dim=0
+                 )  # expected (bs,H,W)
+                 cur_pred_masks = pred_masks[batch_idx]
+                 assert (
+                     cur_gt_masks.shape[0] == cur_pred_masks.shape[0]
+                 ), "gt_masks.shape: {}, pred_masks.shape: {}".format(
+                     cur_gt_masks.shape, cur_pred_masks.shape
+                 )
+                 mask_bce_loss += (
+                     sigmoid_ce_loss(cur_pred_masks, cur_gt_masks, num_masks=cur_gt_masks.shape[0])
+                     * cur_gt_masks.shape[0]
+                 )
+                 mask_dice_loss += (
+                     dice_loss(cur_pred_masks, cur_gt_masks, num_masks=cur_gt_masks.shape[0])
+                     * cur_gt_masks.shape[0]
+                 )
+                 num_masks += cur_gt_masks.shape[0]
+
+             mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
+             mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
+             mask_loss = mask_bce_loss + mask_dice_loss
+
+             loss = ce_loss + mask_loss
+             outputs = CausalLMOutputWithPast(
+                 loss=loss,
+                 logits=model_output.logits,
+                 past_key_values=model_output.past_key_values,
+                 hidden_states=output_hidden_states,
+                 attentions=model_output.attentions,
+             )
+             outputs.ce_loss = ce_loss
+             outputs.mask_bce_loss = mask_bce_loss
+             outputs.mask_dice_loss = mask_dice_loss
+             outputs.mask_loss = mask_loss
+         else:
+             outputs = super().forward(**kwargs)
+         return outputs
+
+     def evaluate(
+         self,
+         tokenizer,
+         query: str,
+         images: List[Tuple[str, str]] = [],
+         hd_num: int = 9,
+         history: List[Tuple[str, str]] = [],
+         max_new_tokens: int = 1024,
+         stream: bool = False,
+         **kwargs,
+     ):
+         with torch.no_grad():
+             inputs, im_mask, _ = self.interleav_wrap_chat(query, images, history=history, hd_num=hd_num)
+             inputs = {
+                 k: v.to(self.device)
+                 for k, v in inputs.items() if torch.is_tensor(v)
+             }
+             eos_token_id = [
+                 tokenizer.eos_token_id,
+                 # tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
+             ]
+             all_pred_masks = []
+
+             if stream:
+                 streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+             else:
+                 streamer = None
+
+             outputs = self.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 im_mask=im_mask,
+                 input_ids=None,
+                 streamer=streamer,
+                 num_beams=1,
+                 do_sample=False,
+                 temperature=1.0,
+                 top_p=1.0,
+                 top_k=0,
+                 eos_token_id=eos_token_id,
+                 repetition_penalty=1.0,
+                 infer_mode='base',
+                 output_hidden_states=True,
+                 return_dict_in_generate=True,
+                 **kwargs,
+             )
+             output_ids = outputs['sequences']
+             response = tokenizer.decode(output_ids[0].cpu().tolist(), skip_special_tokens=True)
+             response = response.replace("[UNUSED_TOKEN_145]", "")
+             history = history + [(query, response)]
+             if len(images) == 1 and isinstance(images[0], str):
+                 output_hidden_states = outputs.hidden_states[-1]
+                 seg_token_mask = output_ids[:, 1:-1] == self.seg_token_idx
+                 inputs_embeds_len = inputs['inputs_embeds'].size(1)
+                 seg_token_mask = torch.cat(
+                     [
+                         torch.zeros((seg_token_mask.shape[0], inputs_embeds_len)).bool().cuda(),
+                         seg_token_mask,
+                     ],
+                     dim=1,
+                 )
+                 hidden_states = []
+                 assert len(self.model.text_hidden_fcs) == 1
+                 hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states))
+                 last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
+                 pred_embeddings = [states[masks] for states, masks in zip(last_hidden_state, seg_token_mask)]
+                 image_g_features, ori_hw = self.encode_g_img(images[0])
+
+                 for i in range(len(pred_embeddings)):
+                     if pred_embeddings[i].numel() == 0:
+                         all_pred_masks.append([])
+                         continue
+                     (sparse_embeddings, dense_embeddings,) = self.model.visual_model.sam_prompt_encoder(
+                         points=None,
+                         boxes=None,
+                         masks=None,
+                         text_embeds=pred_embeddings[i].unsqueeze(1),
+                     )
+                     batch_mode = (pred_embeddings[i].shape[0] > 1)
+                     high_res_features = [
+                         feat_level[i].unsqueeze(0)
+                         for feat_level in image_g_features["high_res_feats"]
+                     ]
+                     sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
+                     image_g_embeds = image_g_features['image_embed'][i].unsqueeze(0).to(torch.bfloat16)
+
+                     low_res_masks, _, _, _ = self.model.visual_model.sam_mask_decoder(
+                         image_embeddings=image_g_embeds,
+                         image_pe=self.model.visual_model.sam_prompt_encoder.get_dense_pe(),
+                         sparse_prompt_embeddings=sparse_embeddings,
+                         dense_prompt_embeddings=dense_embeddings,
+                         repeat_image=batch_mode,
+                         multimask_output=False,
+                         high_res_features=high_res_features,
+                     )
+                     pred_masks = self.model._transform.postprocess_masks(
+                         low_res_masks,
+                         ori_hw[i],
+                     )
+                     all_pred_masks.append(pred_masks[:, 0])
+
+         return response, all_pred_masks
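The evaluate() method above is the single-image inference entry point: it generates a text response, locates [SEG] positions in the generated ids, projects their hidden states through text_hidden_fcs, and prompts the SAM-2 decoder with the result. A minimal usage sketch, assuming `model` (a GeoPixelForCausalLM built by the project's own loading code, with seg_token_idx set to the [SEG] id from added_tokens.json) and `tokenizer` already exist; the query string and image path are illustrative only:

import torch

# Hypothetical handles: `model` and `tokenizer` come from the GeoPixel project's
# own loading code; evaluate() itself is the method defined in geopixel.py above.
query = "Could you give a segmentation mask for the large vehicle?"  # illustrative query
image_path = "example_aerial.jpg"                                    # illustrative image

response, pred_masks = model.evaluate(
    tokenizer,
    query,
    images=[image_path],   # a single image path takes the mask-prediction branch
    max_new_tokens=512,
)

print(response)            # generated text answer
for masks in pred_masks:   # one entry per generated sequence
    if len(masks):
        print(masks.shape)  # per-[SEG] mask logits at the image's original resolution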
pytorch_model-00001-of-00003.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b6eedea61e25c9184b4e2989a3c6d79982a6f4ee263041401cfe443319d04863
+ size 9968330657
pytorch_model-00002-of-00003.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0df18331981448f806bbea1932341d233c2dc9f71a888d16586001efd5c189a4
+ size 9999750322
pytorch_model-00003-of-00003.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4cb744be7e9cc162d6947b63c95cad2370d844b6e5aa76a58a66f61f1bb2d54a
+ size 2709063690
pytorch_model.bin.index.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|action_start|>",
+     "<|action_end|>",
+     "<|interpreter|>",
+     "<|plugin|>"
+   ],
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<unk>",
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenization_internlm2.py ADDED
@@ -0,0 +1,236 @@
+ # coding=utf-8
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Tokenization classes for InternLM."""
+ import os
+ from shutil import copyfile
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import sentencepiece as spm
+ from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
+
+ PRETRAINED_VOCAB_FILES_MAP = {}
+
+
+ # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
+ class InternLM2Tokenizer(PreTrainedTokenizer):
+     """
+     Construct an InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+     Args:
+         vocab_file (`str`):
+             Path to the vocabulary file.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+     model_input_names = ["input_ids", "attention_mask"]
+     _auto_class = "AutoTokenizer"
+
+     def __init__(
+         self,
+         vocab_file,
+         unk_token="<unk>",
+         bos_token="<s>",
+         eos_token="</s>",
+         pad_token="</s>",
+         sp_model_kwargs: Optional[Dict[str, Any]] = None,
+         add_bos_token=True,
+         add_eos_token=False,
+         decode_with_prefix_space=False,
+         clean_up_tokenization_spaces=False,
+         **kwargs,
+     ):
+         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+         self.vocab_file = vocab_file
+         self.add_bos_token = add_bos_token
+         self.add_eos_token = add_eos_token
+         self.decode_with_prefix_space = decode_with_prefix_space
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.Load(vocab_file)
+         self._no_prefix_space_tokens = None
+         super().__init__(
+             bos_token=bos_token,
+             eos_token=eos_token,
+             unk_token=unk_token,
+             pad_token=pad_token,
+             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             **kwargs,
+         )
+
+     @property
+     def no_prefix_space_tokens(self):
+         if self._no_prefix_space_tokens is None:
+             vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
+             self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
+         return self._no_prefix_space_tokens
+
+     @property
+     def vocab_size(self):
+         """Returns vocab size"""
+         return self.sp_model.get_piece_size()
+
+     @property
+     def bos_token_id(self) -> Optional[int]:
+         return self.sp_model.bos_id()
+
+     @property
+     def eos_token_id(self) -> Optional[int]:
+         return self.sp_model.eos_id()
+
+     def get_vocab(self):
+         """Returns vocab as a dict"""
+         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     def _tokenize(self, text):
+         """Returns a tokenized string."""
+         return self.sp_model.encode(text, out_type=str)
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) to an id using the vocab."""
+         return self.sp_model.piece_to_id(token)
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) to a token (str) using the vocab."""
+         token = self.sp_model.IdToPiece(index)
+         return token
+
+     def _maybe_add_prefix_space(self, tokens, decoded):
+         if tokens and tokens[0] not in self.no_prefix_space_tokens:
+             return " " + decoded
+         else:
+             return decoded
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (string) to a single string."""
+         current_sub_tokens = []
+         out_string = ""
+         prev_is_special = False
+         for token in tokens:
+             # make sure that special tokens are not decoded using sentencepiece model
+             if token in self.all_special_tokens:
+                 if not prev_is_special:
+                     out_string += " "
+                 out_string += self.sp_model.decode(current_sub_tokens) + token
+                 prev_is_special = True
+                 current_sub_tokens = []
+             else:
+                 current_sub_tokens.append(token)
+                 prev_is_special = False
+         out_string += self.sp_model.decode(current_sub_tokens)
+         out_string = self.clean_up_tokenization(out_string)
+         out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
+         return out_string[1:]
+
+     def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         """
+         Save the vocabulary and special tokens file to a directory.
+
+         Args:
+             save_directory (`str`):
+                 The directory in which to save the vocabulary.
+
+         Returns:
+             `Tuple(str)`: Paths to the files saved.
+         """
+         if not os.path.isdir(save_directory):
+             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+             return
+         out_vocab_file = os.path.join(
+             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+         )
+
+         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+             copyfile(self.vocab_file, out_vocab_file)
+         elif not os.path.isfile(self.vocab_file):
+             with open(out_vocab_file, "wb") as fi:
+                 content_spiece_model = self.sp_model.serialized_model_proto()
+                 fi.write(content_spiece_model)
+
+         return (out_vocab_file,)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         if self.add_bos_token:
+             bos_token_ids = [self.bos_token_id]
+         else:
+             bos_token_ids = []
+
+         output = bos_token_ids + token_ids_0
+
+         if token_ids_1 is not None:
+             output = output + token_ids_1
+
+         if self.add_eos_token:
+             output = output + [self.eos_token_id]
+
+         return output
+
+     def get_special_tokens_mask(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+     ) -> List[int]:
+         """
+         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` method.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+             )
+
+         if token_ids_1 is None:
+             return [1] + ([0] * len(token_ids_0)) + [1]
+         return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+         use of token type ids, therefore a list of zeros is returned.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+
+         Returns:
+             `List[int]`: List of zeros.
+         """
+         eos = [self.eos_token_id]
+
+         if token_ids_1 is None:
+             return len(token_ids_0 + eos) * [0]
+         return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
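A short sketch of how this tokenizer behaves once loaded (assuming the checkpoint is fetched from the hub by the repo id used above; with add_bos_token=True a <s> id is prepended, and the grounding tokens resolve to the ids from added_tokens.json):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("MBZUAI/GeoPixel-7B-RES", trust_remote_code=True)

ids = tokenizer("segment the storage tanks")["input_ids"]
print(ids[0] == tokenizer.bos_token_id)          # True: BOS (<s>, id 1) is prepended
print(tokenizer.convert_ids_to_tokens(ids[:4]))  # sentencepiece pieces after <s>

# Extra tokens added on top of the sentencepiece vocabulary.
print(tokenizer.convert_tokens_to_ids(["<p>", "</p>", "[SEG]"]))  # [92551, 92552, 92550]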
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f868398fc4e05ee1e8aeba95ddf18ddcc45b8bce55d5093bead5bbf80429b48b
+ size 1477754
tokenizer_config.json ADDED
@@ -0,0 +1,99 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "92538": {
+       "content": "<|plugin|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "92539": {
+       "content": "<|interpreter|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "92540": {
+       "content": "<|action_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "92541": {
+       "content": "<|action_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "92542": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "92543": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|action_start|>",
+     "<|action_end|>",
+     "<|interpreter|>",
+     "<|plugin|>"
+   ],
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_internlm2.InternLM2Tokenizer",
+       null
+     ]
+   },
+   "bos_token": "<s>",
+   "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "</s>",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "</s>",
+   "padding_side": "right",
+   "tokenizer_class": "InternLM2Tokenizer",
+   "unk_token": "<unk>"
+ }
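The chat_template above wraps every turn in <|im_start|>/<|im_end|> markers (ChatML style). A minimal sketch of rendering a prompt with it; apply_chat_template requires a reasonably recent transformers version, while the template string itself is exactly the one in the file:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("MBZUAI/GeoPixel-7B-RES", trust_remote_code=True)

messages = [{"role": "user", "content": "Describe the scene in this aerial image."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# <s><|im_start|>user
# Describe the scene in this aerial image.<|im_end|>
# <|im_start|>assistant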