Upload folder using huggingface_hub
geopixel.py  CHANGED  (+13, -6)

@@ -12,6 +12,7 @@ from model.IXC.modeling_internlm_xcomposer2 import InternLMXComposer2ForCausalLM
 from model.IXC.modeling_internlm2 import InternLM2Model
 from model.sam2.build_sam import build_sam2_hf
 from model.sam2.utils.transforms import SAM2Transforms
+from transformers import TextStreamer
 try:
     from transformers.generation.streamers import BaseStreamer
 except: # noqa # pylint: disable=bare-except
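
As context for the new import: TextStreamer is the stock Transformers streamer that decodes and prints tokens to stdout while generate() runs. A minimal standalone sketch of that behavior (the gpt2 checkpoint is only an illustrative stand-in, not part of this repo):

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

    tok = AutoTokenizer.from_pretrained("gpt2")            # placeholder model for illustration
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    streamer = TextStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    inputs = tok("Streaming test:", return_tensors="pt")
    # Decoded text is printed incrementally while generate() runs.
    model.generate(**inputs, max_new_tokens=20, streamer=streamer)
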
@@ -93,8 +94,10 @@ class GeoPixelMetaModel:
             (128, 128),
             (64, 64),
         ]
+
         for param in self.visual_model.parameters():
             param.requires_grad = False
+
         if config.train_mask_decoder:
             self.visual_model.sam_mask_decoder.train()
             for param in self.visual_model.sam_mask_decoder.parameters():
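
The two added blank lines only separate the freeze step from the selective re-enabling of the SAM2 mask decoder (the hunk cuts off before the body of the last loop, which presumably sets requires_grad back to True). The underlying PyTorch pattern, as a hedged sketch with illustrative names:

    import torch.nn as nn

    def freeze_except_decoder(visual_model: nn.Module) -> None:
        # Freeze every parameter of the visual backbone...
        for p in visual_model.parameters():
            p.requires_grad = False
        # ...then re-enable gradients for the mask decoder alone.
        for p in visual_model.sam_mask_decoder.parameters():
            p.requires_grad = True
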
@@ -195,6 +198,8 @@ class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
         samples = kwargs.get('samples', None)
         if samples and samples['data_type'][0] == 'grounding':
             kwargs['output_hidden_states'] = True
+            kwargs['use_cache'] = False
+
         torch.cuda.empty_cache()
         outputs = super().forward(**kwargs)

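
This is the one functional change in forward(): grounding batches now disable the KV cache while requesting hidden states, since a single training pass never reuses cached keys/values. A runnable illustration of what the two kwargs do on any HF causal LM (gpt2 here is only a stand-in):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    batch = tok("a grounding sample", return_tensors="pt")
    out = model(**batch, output_hidden_states=True, use_cache=False)

    print(len(out.hidden_states))       # embedding layer + one tensor per block
    print(out.hidden_states[-1].shape)  # (batch, seq_len, hidden_size)
    print(out.past_key_values)          # None: the cache was disabled
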
@@ -246,9 +251,6 @@ class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
                     low_res_masks,
                     ori_hw[i],
                 )
-
-                # pred_masks = pred_masks.squeeze(0)
-                # all_pred_masks.append(pred_masks)
                 all_pred_masks.append(pred_masks[:, 0])


@@ -320,27 +322,32 @@ class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
         hd_num: int = 9,
         history: List[Tuple[str, str]] = [],
         max_new_tokens: int = 1024,
+        stream: bool = False,
         **kwargs,
     ):
         with torch.no_grad():
             inputs, im_mask, _ = self.interleav_wrap_chat(query, images, history=history, hd_num=hd_num)
-            print(im_mask.sum().item())
             inputs = {
                 k: v.to(self.device)
                 for k, v in inputs.items() if torch.is_tensor(v)
             }
-            # print(len(inputs['inputs_embeds'][0]))
             eos_token_id = [
                 tokenizer.eos_token_id,
                 #tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
             ]
             all_pred_masks = []
+
+            if stream:
+                streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            else:
+                streamer = None
+
             outputs = self.generate(
                 **inputs,
                 max_new_tokens=max_new_tokens,
                 im_mask=im_mask,
                 input_ids = None,
-                streamer=
+                streamer= streamer,
                 num_beams=1,
                 do_sample=False,
                 temperature=1.0,
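
Taken together, the chat path gains an opt-in live-printing mode. A hedged usage sketch: the enclosing method's def line sits outside this hunk (GeoPixel's chat-style entry point, assumed here to be evaluate), and the model, tokenizer, and image path are presumed already set up:

    # 'evaluate', the query text, and the image path are illustrative assumptions.
    output = model.evaluate(
        tokenizer,
        query="<image> Please segment the river in this image.",
        images=["example.jpg"],   # placeholder path
        max_new_tokens=1024,
        stream=True,              # new flag: tokens print live via TextStreamer
    )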