PyTorch
English
internlm2
code
custom_code
AkashahS committed (verified)
Commit 0d9b822 · 1 Parent(s): ad60165

Delete geopix.py

Files changed (1)
  1. geopix.py +0 -410
geopix.py DELETED
@@ -1,410 +0,0 @@
from typing import List, Optional, Tuple, Union

import os
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from model.IXC.modeling_internlm_xcomposer2 import InternLMXComposer2ForCausalLM
from model.IXC.modeling_internlm2 import InternLM2Model
from model.sam2.build_sam import build_sam2_hf
from model.sam2.utils.transforms import SAM2Transforms

try:
    from transformers.generation.streamers import BaseStreamer
except:  # noqa # pylint: disable=bare-except
    BaseStreamer = None


def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,  # 100000.0,
    eps=1e-6,
):
    """
    Compute the DICE loss, similar to generalized IoU for masks.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1, 2)
    targets = targets.flatten(1, 2)
    numerator = 2 * (inputs / scale * targets).sum(-1)
    denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
    loss = 1 - (numerator + eps) / (denominator + eps)
    loss = loss.sum() / (num_masks + 1e-8)
    return loss
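

# --- Editor's note: illustrative sketch, not part of the original geopix.py. ---
# Minimal shape check for dice_loss: predictions and targets are (N, H, W) mask
# logits/labels for one image, and num_masks normalises the summed per-mask loss.
def _demo_dice_loss():
    preds = torch.randn(2, 64, 64)                   # raw logits for 2 predicted masks
    targets = (torch.rand(2, 64, 64) > 0.5).float()  # binary ground-truth masks
    return dice_loss(preds, targets, num_masks=2)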


def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    Returns:
        Loss tensor
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
    return loss
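

# --- Editor's note: illustrative sketch, not part of the original geopix.py. ---
# sigmoid_ce_loss takes the same (N, H, W) logits/targets pair as dice_loss; during
# training the two are combined as bce_loss_weight * bce + dice_loss_weight * dice.
def _demo_mask_losses():
    preds = torch.randn(3, 64, 64)
    targets = (torch.rand(3, 64, 64) > 0.5).float()
    return sigmoid_ce_loss(preds, targets, num_masks=3), dice_loss(preds, targets, num_masks=3)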


class GeoPixMetaModel:
    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(GeoPixMetaModel, self).__init__(config)
        self.config = config
        self.config.train_mask_decoder = getattr(
            self.config, "train_mask_decoder", kwargs.get("train_mask_decoder", False)
        )
        self.config.out_dim = getattr(self.config, "out_dim", kwargs.get("out_dim", 256))
        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.initialize_geopix_modules(self.config)

    def initialize_geopix_modules(self, config):
        # grounding vision model
        self.visual_model = build_sam2_hf(self.vision_pretrained)

        self._transform = SAM2Transforms(
            resolution=self.visual_model.image_size,
            mask_threshold=0.0,
            max_hole_area=0.0,
            max_sprinkle_area=0.0,
        )
        # Spatial dims for backbone feature maps
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if config.train_mask_decoder:
            self.visual_model.sam_mask_decoder.train()
            for param in self.visual_model.sam_mask_decoder.parameters():
                param.requires_grad = True

        # text projection layer
        in_dim = config.hidden_size
        out_dim = config.out_dim
        text_projection_layers = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_projection_layers)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True
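

# --- Editor's note: illustrative sketch, not part of the original geopix.py. ---
# initialize_geopix_modules freezes the SAM2 backbone and leaves the SAM2 mask decoder
# (when config.train_mask_decoder is set) and text_hidden_fcs trainable. A hypothetical
# helper for building an optimizer over just the trainable parameters:
def _trainable_parameters(module):
    return [p for p in module.parameters() if p.requires_grad]
# e.g. optimizer = torch.optim.AdamW(_trainable_parameters(model), lr=3e-4)  # lr is illustrative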


class GeoPixModel(GeoPixMetaModel, InternLM2Model):
    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(GeoPixModel, self).__init__(config, **kwargs)
        self.config.use_cache = False


class GeoPixForCausalLM(InternLMXComposer2ForCausalLM):
    def __init__(self, config, **kwargs):
        self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
        self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        self.seg_token_idx = kwargs.pop("seg_token_idx")

        super().__init__(config)
        self.model = GeoPixModel(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()
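
    # --- Editor's note: illustrative sketch, not part of the original geopix.py. ---
    # seg_token_idx is the vocabulary id of the special segmentation token whose hidden
    # state is projected (via text_hidden_fcs) into a SAM2 prompt embedding. A hypothetical
    # construction, assuming the token is literally named "[SEG]" and the loss weights
    # shown are placeholders:
    #
    #   tokenizer.add_tokens("[SEG]", special_tokens=True)
    #   model = GeoPixForCausalLM.from_pretrained(
    #       checkpoint_path,                                  # placeholder
    #       seg_token_idx=tokenizer.convert_tokens_to_ids("[SEG]"),
    #       ce_loss_weight=1.0, bce_loss_weight=2.0, dice_loss_weight=0.5,
    #   )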

    def encode_g_img(self, image):
        """
        Calculates the image embeddings for the provided image.
        Arguments:
            image (str or torch.Tensor)
        """
        if image is None:
            return None
        if isinstance(image, str):
            _, ext = os.path.splitext(image)
            if ext.lower() in {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}:
                image = Image.open(image)
                w, h = image.size
                _orig_hw = [(h, w)]
            else:
                print('Unknown input format', image)
                return None
        else:
            assert isinstance(image, torch.Tensor)
            _orig_hw = [image.shape[:2]]
        image = self.model._transform(image)
        image = image[None, ...].to(self.device)
        assert (
            len(image.shape) == 4 and image.shape[1] == 3
        ), f"image must be of size 1x3xHxW, got {image.shape}"
        features = self.get_visual_embs(image)
        return features, _orig_hw
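
    # --- Editor's note: illustrative usage sketch, not part of the original geopix.py. ---
    # encode_g_img accepts an image path or a torch.Tensor and returns the SAM2 feature
    # dict plus the original (H, W) sizes, e.g. (the path is a placeholder):
    #
    #   features, ori_hw = model.encode_g_img("scene.png")
    #   features["image_embed"]     # lowest-resolution image embedding (batch of 1)
    #   features["high_res_feats"]  # higher-resolution features fed to the mask decoder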

    def get_visual_embs(self, img_batch: torch.FloatTensor):
        with torch.no_grad():
            torch.cuda.empty_cache()
            img_batch = img_batch.to(self.device)
            batch_size = img_batch.shape[0]
            assert (
                len(img_batch.shape) == 4 and img_batch.shape[1] == 3
            ), f"grounding_img_batch must be of size Bx3xHxW, got {img_batch.shape}"
            backbone_out = self.model.visual_model.forward_image(img_batch)
            _, vision_feats, _, _ = self.model.visual_model._prepare_backbone_features(backbone_out)
            if self.model.visual_model.directly_add_no_mem_embed:
                vision_feats[-1] = vision_feats[-1] + self.model.visual_model.no_mem_embed
            feats = [
                feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
                for feat, feat_size in zip(vision_feats[::-1], self.model._bb_feat_sizes[::-1])
            ][::-1]
            features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
            return features

    def forward(self, **kwargs):
        return super().forward(**kwargs) if "past_key_values" in kwargs else self.model_forward(**kwargs)

    def model_forward(
        self,
        inference: bool = False,
        **kwargs,
    ):
        samples = kwargs.get('samples', None)
        if samples and samples['data_type'][0] == 'grounding':
            kwargs['output_hidden_states'] = True
            torch.cuda.empty_cache()
            outputs = super().forward(**kwargs)

            # read the segmentation-token mask before `outputs` is dropped in inference mode
            seg_token_mask = outputs.seg_token_mask
            if inference:
                assert len(samples['text_input']) == 1 and len(samples['image'][0]) == 1  # single image and single query
                output_hidden_states = [outputs.hidden_states]
                outputs = None
            else:
                output_hidden_states = outputs.hidden_states

            hidden_states = []
            assert len(self.model.text_hidden_fcs) == 1
            hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states[-1]))
            last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)

            pred_embeddings = [states[masks] for states, masks in zip(last_hidden_state, seg_token_mask)]
            image_g_batch = torch.cat(samples['image_g'][0], dim=0)
            image_g_features = self.get_visual_embs(image_g_batch)
            ori_hw = samples['ori_hw'][0]
            all_pred_masks = []
            for i in range(len(pred_embeddings)):  # (bs,)
                if pred_embeddings[i].numel() == 0:
                    all_pred_masks.append([])
                    continue
                (sparse_embeddings, dense_embeddings,) = self.model.visual_model.sam_prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=pred_embeddings[i].unsqueeze(1),
                )
                batch_mode = (pred_embeddings[i].shape[0] > 1)
                high_res_features = [
                    feat_level[i].unsqueeze(0)
                    for feat_level in image_g_features["high_res_feats"]
                ]
                sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                image_g_embeds = image_g_features['image_embed'][i].unsqueeze(0).to(torch.bfloat16)
                low_res_masks, _, _, _ = self.model.visual_model.sam_mask_decoder(
                    image_embeddings=image_g_embeds,
                    image_pe=self.model.visual_model.sam_prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    repeat_image=batch_mode,
                    multimask_output=False,
                    high_res_features=high_res_features,
                )
                pred_masks = self.model._transform.postprocess_masks(
                    low_res_masks,
                    ori_hw[i],
                )

                # pred_masks = pred_masks.squeeze(0)
                # all_pred_masks.append(pred_masks)
                all_pred_masks.append(pred_masks[:, 0])

            model_output = outputs
            gt_masks = samples['masks'][0]
            pred_masks = all_pred_masks

            if inference:
                return {
                    "pred_masks": pred_masks,
                    "gt_masks": gt_masks,
                }

            ce_loss = model_output.loss
            ce_loss = ce_loss * self.ce_loss_weight
            mask_bce_loss = 0
            mask_dice_loss = 0
            num_masks = 0

            for batch_idx in range(len(pred_masks)):  # for every image
                cur_gt_masks = torch.stack(
                    [
                        torch.from_numpy(gt_mask).to(dtype=pred_masks[batch_idx].dtype, device=pred_masks[batch_idx].device)
                        for gt_mask in gt_masks[batch_idx]
                    ],
                    dim=0,
                )  # expected (bs, H, W)
                cur_pred_masks = pred_masks[batch_idx]
                assert (
                    cur_gt_masks.shape[0] == cur_pred_masks.shape[0]
                ), "gt_masks.shape: {}, pred_masks.shape: {}".format(
                    cur_gt_masks.shape, cur_pred_masks.shape
                )
                mask_bce_loss += (
                    sigmoid_ce_loss(cur_pred_masks, cur_gt_masks, num_masks=cur_gt_masks.shape[0])
                    * cur_gt_masks.shape[0]
                )
                mask_dice_loss += (
                    dice_loss(cur_pred_masks, cur_gt_masks, num_masks=cur_gt_masks.shape[0])
                    * cur_gt_masks.shape[0]
                )
                num_masks += cur_gt_masks.shape[0]

            mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
            mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
            mask_loss = mask_bce_loss + mask_dice_loss

            loss = ce_loss + mask_loss
            outputs = CausalLMOutputWithPast(
                loss=loss,
                logits=model_output.logits,
                past_key_values=model_output.past_key_values,
                hidden_states=output_hidden_states,
                attentions=model_output.attentions,
            )
            outputs.ce_loss = ce_loss
            outputs.mask_bce_loss = mask_bce_loss
            outputs.mask_dice_loss = mask_dice_loss
            outputs.mask_loss = mask_loss
        else:
            outputs = super().forward(**kwargs)
        return outputs
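
    # --- Editor's note: descriptive sketch, not part of the original geopix.py. ---
    # In the grounding branch, model_forward expects kwargs['samples'] shaped roughly
    # as below (inferred from the accesses above; outer batch dimension kept at one):
    #
    #   samples = {
    #       'data_type':  ['grounding'],
    #       'text_input': [query_string],
    #       'image':      [[llm_image]],                       # images for the LLM branch
    #       'image_g':    [[tensor_1x3xHxW, ...]],             # SAM2-resolution grounding images
    #       'ori_hw':     [[(H, W), ...]],                     # original sizes for postprocess_masks
    #       'masks':      [[[mask_HxW_ndarray, ...], ...]],    # per-image lists of ground-truth numpy masks
    #   }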

    def evaluate(
        self,
        tokenizer,
        query: str,
        images: List[Tuple[str, str]] = [],
        hd_num: int = 9,
        history: List[Tuple[str, str]] = [],
        max_new_tokens: int = 1024,
        **kwargs,
    ):
        with torch.no_grad():
            inputs, im_mask, _ = self.interleav_wrap_chat(query, images, history=history, hd_num=hd_num)
            inputs = {
                k: v.to(self.device)
                for k, v in inputs.items() if torch.is_tensor(v)
            }
            # print(len(inputs['inputs_embeds'][0]))
            eos_token_id = [
                tokenizer.eos_token_id,
                # tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
            ]
            all_pred_masks = []
            outputs = self.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                im_mask=im_mask,
                input_ids=None,
                streamer=None,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                top_p=1.0,
                top_k=0,
                eos_token_id=eos_token_id,
                repetition_penalty=1.0,
                infer_mode='base',
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            output_ids = outputs['sequences']
            response = tokenizer.decode(output_ids[0].cpu().tolist(), skip_special_tokens=True)
            response = response.replace("[UNUSED_TOKEN_145]", "")
            history = history + [(query, response)]
            if len(images) == 1 and isinstance(images[0], str):
                output_hidden_states = outputs.hidden_states[-1]
                seg_token_mask = output_ids[:, 1:-1] == self.seg_token_idx
                inputs_embeds_len = inputs['inputs_embeds'].size(1)
                seg_token_mask = torch.cat(
                    [
                        torch.zeros((seg_token_mask.shape[0], inputs_embeds_len)).bool().cuda(),
                        seg_token_mask,
                    ],
                    dim=1,
                )
                hidden_states = []
                assert len(self.model.text_hidden_fcs) == 1
                hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states))
                last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
                pred_embeddings = [states[masks] for states, masks in zip(last_hidden_state, seg_token_mask)]
                image_g_features, ori_hw = self.encode_g_img(images[0])

                for i in range(len(pred_embeddings)):
                    if pred_embeddings[i].numel() == 0:
                        all_pred_masks.append([])
                        continue
                    (sparse_embeddings, dense_embeddings,) = self.model.visual_model.sam_prompt_encoder(
                        points=None,
                        boxes=None,
                        masks=None,
                        text_embeds=pred_embeddings[i].unsqueeze(1),
                    )
                    batch_mode = (pred_embeddings[i].shape[0] > 1)
                    high_res_features = [
                        feat_level[i].unsqueeze(0)
                        for feat_level in image_g_features["high_res_feats"]
                    ]
                    sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                    image_g_embeds = image_g_features['image_embed'][i].unsqueeze(0).to(torch.bfloat16)

                    low_res_masks, _, _, _ = self.model.visual_model.sam_mask_decoder(
                        image_embeddings=image_g_embeds,
                        image_pe=self.model.visual_model.sam_prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=sparse_embeddings,
                        dense_prompt_embeddings=dense_embeddings,
                        repeat_image=batch_mode,
                        multimask_output=False,
                        high_res_features=high_res_features,
                    )
                    pred_masks = self.model._transform.postprocess_masks(
                        low_res_masks,
                        ori_hw[i],
                    )
                    all_pred_masks.append(pred_masks[:, 0])

        return response, all_pred_masks
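
A minimal end-to-end sketch of how the deleted module was meant to be used, assuming a local GeoPix checkpoint directory, a SAM2 backbone id, a "[SEG]"-style segmentation token, and a CUDA device; the paths, token name, backbone id, and prompt below are illustrative placeholders, not documented API.

# Editor's sketch (not part of geopix.py): load a checkpoint and run evaluate().
import torch
from transformers import AutoTokenizer

from geopix import GeoPixForCausalLM  # the module shown above

ckpt = "path/to/geopix-checkpoint"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
seg_token_idx = tokenizer.convert_tokens_to_ids("[SEG]")  # assumed segmentation token

model = GeoPixForCausalLM.from_pretrained(
    ckpt,
    seg_token_idx=seg_token_idx,
    vision_pretrained="facebook/sam2-hiera-large",  # assumed SAM2 backbone
    torch_dtype=torch.bfloat16,
).cuda().eval()

response, pred_masks = model.evaluate(
    tokenizer,
    query="Please segment the storage tanks in the image.",  # illustrative prompt
    images=["demo.png"],  # placeholder image path
    max_new_tokens=512,
)
print(response)         # generated text answer
print(len(pred_masks))  # one entry for the single query; each entry stacks one (H, W)
                        # mask per predicted segmentation token (empty list if none)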