AkashahS committed · Commit bc762dc · verified · 1 Parent(s): 0d9b822

Upload folder using huggingface_hub

Files changed (1): geopixel.py (+13, −6)
geopixel.py CHANGED
@@ -12,6 +12,7 @@ from model.IXC.modeling_internlm_xcomposer2 import InternLMXComposer2ForCausalLM
 from model.IXC.modeling_internlm2 import InternLM2Model
 from model.sam2.build_sam import build_sam2_hf
 from model.sam2.utils.transforms import SAM2Transforms
+from transformers import TextStreamer
 try:
     from transformers.generation.streamers import BaseStreamer
 except: # noqa # pylint: disable=bare-except
@@ -93,8 +94,10 @@ class GeoPixelMetaModel:
             (128, 128),
             (64, 64),
         ]
+
         for param in self.visual_model.parameters():
             param.requires_grad = False
+
         if config.train_mask_decoder:
             self.visual_model.sam_mask_decoder.train()
             for param in self.visual_model.sam_mask_decoder.parameters():
@@ -195,6 +198,8 @@ class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
         samples = kwargs.get('samples', None)
         if samples and samples['data_type'][0] == 'grounding':
             kwargs['output_hidden_states'] = True
+            kwargs['use_cache'] = False
+
         torch.cuda.empty_cache()
         outputs = super().forward(**kwargs)
 
@@ -246,9 +251,6 @@ class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
                 low_res_masks,
                 ori_hw[i],
             )
-
-            # pred_masks = pred_masks.squeeze(0)
-            # all_pred_masks.append(pred_masks)
             all_pred_masks.append(pred_masks[:, 0])
 
 
@@ -320,27 +322,32 @@ class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
         hd_num: int = 9,
         history: List[Tuple[str, str]] = [],
         max_new_tokens: int = 1024,
+        stream: bool = False,
         **kwargs,
     ):
         with torch.no_grad():
             inputs, im_mask, _ = self.interleav_wrap_chat(query, images, history=history, hd_num=hd_num)
-            print(im_mask.sum().item())
             inputs = {
                 k: v.to(self.device)
                 for k, v in inputs.items() if torch.is_tensor(v)
             }
-            # print(len(inputs['inputs_embeds'][0]))
             eos_token_id = [
                 tokenizer.eos_token_id,
                 #tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
             ]
             all_pred_masks = []
+
+            if stream:
+                streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            else:
+                streamer = None
+
             outputs = self.generate(
                 **inputs,
                 max_new_tokens=max_new_tokens,
                 im_mask=im_mask,
                 input_ids = None,
-                streamer= None,
+                streamer= streamer,
                 num_beams=1,
                 do_sample=False,
                 temperature=1.0,
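
The grounding branch of `forward` now sets `use_cache=False` alongside `output_hidden_states=True`. A minimal sketch of why, using `gpt2` as a stand-in model (this is not code from this repo): the KV cache only speeds up incremental, token-by-token decoding, so a single full-sequence forward that only needs hidden states can skip allocating it and save memory.

```python
# Hedged sketch with a stand-in model: a full-sequence forward that wants
# hidden states (as GeoPixel's grounding path does) gains nothing from the
# KV cache, so disabling it avoids the extra allocation.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2").eval()

batch = tok("segment the runway in the image", return_tensors="pt")
with torch.no_grad():
    out = lm(**batch, output_hidden_states=True, use_cache=False)

print(out.hidden_states[-1].shape)  # last-layer hidden states: (1, seq_len, 768)
print(out.past_key_values)          # None: no cache was allocated
```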
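
The remaining hunks thread an optional streamer through the chat-style entry point: a new `stream: bool = False` argument builds a `TextStreamer` and passes it to `generate`. A minimal usage sketch of that pattern, again with `gpt2` as a stand-in (GeoPixel's full method signature and return values are not shown in this diff):

```python
# Hedged sketch of the streaming pattern the commit adopts. With a streamer,
# decoded text is printed to stdout as tokens are generated; with
# streamer=None the call behaves exactly as before.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2").eval()

stream = True  # mirrors the new `stream: bool = False` parameter
streamer = TextStreamer(tok, skip_prompt=True, skip_special_tokens=True) if stream else None

inputs = tok("Describe the objects in the scene:", return_tensors="pt")
with torch.no_grad():
    lm.generate(**inputs, max_new_tokens=32, streamer=streamer, do_sample=False)
```

Because `streamer` stays `None` unless `stream=True` is passed, existing callers see no behavior change; only opt-in callers get incremental output.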