Add batching support #1
by robertsonwang · opened

modeling_tinyllava_phi.py  CHANGED  (+105 -62)
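For context, a minimal sketch of calling the reworked chat() API through transformers with trust_remote_code. The checkpoint id, image path, and generation settings below are placeholders rather than anything shipped with this PR.

# Usage sketch only; the checkpoint id and image path are placeholders, not part of this PR.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B"  # placeholder checkpoint id
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, trust_remote_code=True
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# chat() now returns a list with one decoded string per prompt, plus the generation time;
# per the new docstring it also accepts a list of prompts with a matching list of images.
outputs, generation_time = model.chat(
    prompts="What objects are on the table?",
    tokenizer=tokenizer,
    images="example.jpg",  # placeholder local path or URL
    max_new_tokens=256,
)
print(outputs[0], f"({generation_time:.2f}s)")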
@@ -1,6 +1,7 @@
 # For licensing see accompanying LICENSE file.
 # Copyright (C) 2024 TinyLLaVA. All Rights Reserved.
 import time
+import numpy as np

 import dataclasses
 from enum import auto, Enum
@@ -160,26 +161,40 @@ def process_images(images, image_processor, model_cfg):
     return new_images


+def tokenizer_image_token(prompts, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    def process_single_prompt(prompt):
+        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

+        def insert_separator(X, sep):
+            return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

+        input_ids = []
+        offset = 0
+        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+            offset = 1
+            input_ids.append(prompt_chunks[0][0])

+        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+            input_ids.extend(x[offset:])
+
+        return input_ids
+
+    if isinstance(prompts, str): # Handle single prompt
+        return process_single_prompt(prompts)
+
+    # Handle batch of prompts
+    batch_input_ids = [process_single_prompt(prompt) for prompt in prompts]

     if return_tensors is not None:
         if return_tensors == 'pt':
+            max_length = max(len(ids) for ids in batch_input_ids)
+            padded_input_ids = [
+                ids + [tokenizer.pad_token_id] * (max_length - len(ids)) for ids in batch_input_ids
+            ]
+            return torch.tensor(padded_input_ids, dtype=torch.long)
         raise ValueError(f'Unsupported tensor type: {return_tensors}')
+
+    return batch_input_ids

 def load_image(image_file):
     if image_file.startswith("http") or image_file.startswith("https"):
@@ -204,9 +219,9 @@ class Connector(nn.Module):
         for _ in range(1, mlp_depth):
             modules.append(ACT_TYPE[act_type]())
             modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+
         self._connector = nn.Sequential(*modules)
+
     def forward(self, x):
         return self._connector(x)

@@ -219,9 +234,9 @@ class VisionTower(nn.Module):
         else:
             self._vision_tower = SiglipVisionModel(cfg)
             self._image_processor = SiglipImageProcessor.from_pretrained(cfg.model_name_or_path)
+
         self.config = cfg
+
     def forward(self, x, **kwargs):
         image_features = self._vision_tower(x, output_hidden_states=True)
         image_features = image_features.hidden_states[kwargs.get('vision_feature_layer', -2)]
@@ -234,11 +249,11 @@ class VisionTower(nn.Module):
             raise ValueError(f"Unexpected select feature: {kwargs.get('vision_feature_select_strategy')}")

         return image_features
+
     @property
     def vision_tower(self):
         return self._vision_tower
+
     @vision_tower.setter
     def vision_tower(self, vision_tower):
         self._vision_tower = vision_tower
@@ -248,7 +263,7 @@ def get_value_from_kwargs(kwargs, name):
         return kwargs.pop(name)
     else:
         return None
+

 class TinyLlavaPreTrainedModel(PreTrainedModel):
     config_class = TinyLlavaConfig
@@ -284,7 +299,7 @@ class TinyLlavaPreTrainedModel(PreTrainedModel):

 class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
     def __init__(self, config: TinyLlavaConfig):
+
         super().__init__(config)

         self.language_model = PhiForCausalLM(config.text_config)
@@ -292,7 +307,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
         self.connector = Connector(config)
         self.post_init()

+
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()

@@ -322,7 +337,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds

+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -368,7 +383,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict
         )
+
     @torch.no_grad()
     def generate(
         self,
@@ -408,7 +423,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
             inputs_embeds=inputs_embeds,
             **kwargs
         )
+
     def encode_images(self, images):
         kwargs = {}
         kwargs['vision_feature_layer'] = self.config.vision_feature_layer
@@ -417,9 +432,9 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
         image_features = self.vision_tower(images, **kwargs)
         image_features = self.connector(image_features)
         return image_features
+
+
+
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                       inputs_embeds=None, **kwargs):
         images = kwargs.pop("images", None)
@@ -432,7 +447,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
         if image_sizes is not None:
             inputs['image_sizes'] = image_sizes
         return inputs
+
     def prepare_inputs_labels_for_multimodal(
         self, input_ids, position_ids, attention_mask, past_key_values, labels,
         images, image_sizes=None
@@ -441,7 +456,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
         if vision_tower is None or images is None or input_ids.shape[1] == 1:
             return input_ids, position_ids, attention_mask, past_key_values, None, labels

+
         image_features = self.encode_images(images)

         # TODO: image start / end is not implemented here to support pretraining.
@@ -565,40 +580,72 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
             position_ids = None

         return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
     def chat(
         self,
+        prompts: Union[list, str],
+        tokenizer=None,
+        images: Union[list, str] = None,
         max_new_tokens: int = 512,
+        num_beams=1,
         top_p=None,
         temperature=0
     ):
+        """
+        Generate responses for a batch of prompts.
+
+        Args:
+            prompts (list): List of text prompts.
+            tokenizer: Tokenizer object.
+            images (list, optional): List of image file paths corresponding to the prompts. Defaults to None.
+            max_new_tokens (int): Maximum number of new tokens to generate. Defaults to 512.
+            num_beams (int): Number of beams for beam search. Defaults to 1.
+            top_p (float, optional): Nucleus sampling probability. Defaults to None.
+            temperature (float): Sampling temperature. Defaults to 0.
+
+        Returns:
+            list: List of generated outputs.
+            list: List of generation times for each batch.
+        """
+        if isinstance(prompts, list) and isinstance(images, str):
+            assert len(prompts) == len(images) or images is None, "Mismatch between prompts and images."
+        else:
+            prompts = [prompts]
+            images = [images]
         image_processor = self.vision_tower._image_processor

+        # Prepare inputs
+        input_texts = []
+        image_tensors = None
+
+        for i, prompt in enumerate(prompts):
+            if images and images[i] is not None:
+                prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
+            conv = conv_phi_v0.copy()
+
+            conv.append_message(conv.roles[0], prompt)
+            conv.append_message(conv.roles[1], None)
+            input_texts.append(conv.get_prompt())
+
+        # Tokenize prompts
+        input_ids = tokenizer_image_token(
+            input_texts, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+        ).to(self.device)
+
+        # Process images
+        if images:
+            processed_images = [
+                process_images(load_image(image), image_processor, self.config)
+                for image in images if image is not None
+            ]
+            image_tensors = torch.stack(processed_images).to(self.device).squeeze(1)
+
+        # Generate responses
         stime = time.time()
         with torch.inference_mode():
             output_ids = self.generate(
                 input_ids,
+                images=image_tensors,
                 do_sample=True if temperature > 0 else False,
                 temperature=temperature,
                 top_p=top_p,
@@ -606,19 +653,15 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
                 pad_token_id=tokenizer.pad_token_id,
                 max_new_tokens=max_new_tokens,
                 use_cache=True,
-                # stopping_criteria=[stopping_criteria],
             )
-        # print('inference over')
         generation_time = time.time() - stime
-        outputs = tokenizer.batch_decode(
-            output_ids, skip_special_tokens=True
-        )[0]

+        # Decode outputs
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        outputs = [output.strip() for output in outputs]

         return outputs, generation_time

+
+AutoConfig.register("tinyllava", TinyLlavaConfig)
+AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
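For reference, a sketch of what the batched tokenizer_image_token added above produces. The tokenizer choice and prompt strings are illustrative, and the sketch assumes tokenizer_image_token and the IMAGE_TOKEN_INDEX sentinel from modeling_tinyllava_phi.py are in scope.

# Illustrative sketch, not part of the PR.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", use_fast=False)  # placeholder tokenizer
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id  # shorter prompts are padded with this id

prompts = [
    "<image>\nWhat is in the picture?",
    "<image>\nGive a detailed, multi-sentence description of the picture.",
]

# Each prompt is split on '<image>', the text chunks are tokenized, the chunks are
# rejoined with IMAGE_TOKEN_INDEX, and shorter sequences are right-padded with
# pad_token_id so the whole batch stacks into a single LongTensor.
batch = tokenizer_image_token(prompts, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
print(batch.shape)  # (2, length of the longest tokenized prompt)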