Files changed (1)
  1. modeling_tinyllava_phi.py +105 -62
modeling_tinyllava_phi.py CHANGED
@@ -1,6 +1,7 @@
 # For licensing see accompanying LICENSE file.
 # Copyright (C) 2024 TinyLLaVA. All Rights Reserved.
 import time
+import numpy as np
 
 import dataclasses
 from enum import auto, Enum
@@ -160,26 +161,40 @@ def process_images(images, image_processor, model_cfg):
     return new_images
 
 
-def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
-    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+def tokenizer_image_token(prompts, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    def process_single_prompt(prompt):
+        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
 
-    def insert_separator(X, sep):
-        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+        def insert_separator(X, sep):
+            return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
 
-    input_ids = []
-    offset = 0
-    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
-        offset = 1
-        input_ids.append(prompt_chunks[0][0])
+        input_ids = []
+        offset = 0
+        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+            offset = 1
+            input_ids.append(prompt_chunks[0][0])
 
-    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
-        input_ids.extend(x[offset:])
+        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+            input_ids.extend(x[offset:])
+
+        return input_ids
+
+    if isinstance(prompts, str):  # Handle single prompt
+        return process_single_prompt(prompts)
+
+    # Handle batch of prompts
+    batch_input_ids = [process_single_prompt(prompt) for prompt in prompts]
 
     if return_tensors is not None:
         if return_tensors == 'pt':
-            return torch.tensor(input_ids, dtype=torch.long)
+            max_length = max(len(ids) for ids in batch_input_ids)
+            padded_input_ids = [
+                ids + [tokenizer.pad_token_id] * (max_length - len(ids)) for ids in batch_input_ids
+            ]
+            return torch.tensor(padded_input_ids, dtype=torch.long)
         raise ValueError(f'Unsupported tensor type: {return_tensors}')
-    return input_ids
+
+    return batch_input_ids
 
 def load_image(image_file):
     if image_file.startswith("http") or image_file.startswith("https"):
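Note: the reworked tokenizer_image_token accepts either a single prompt or a list; batched prompts are right-padded to the longest sequence with tokenizer.pad_token_id before conversion to a tensor. A minimal sketch of the batched path, assuming tokenizer_image_token is importable from modeling_tinyllava_phi; the stub tokenizer, its token ids, and the -200 index (the usual LLaVA-style IMAGE_TOKEN_INDEX) are illustrative assumptions, not part of the diff:

from types import SimpleNamespace

from modeling_tinyllava_phi import tokenizer_image_token

class StubTokenizer:
    bos_token_id = 1
    pad_token_id = 0

    def __call__(self, text):
        # Fake tokenizer: BOS followed by one id per whitespace-separated word.
        return SimpleNamespace(input_ids=[self.bos_token_id] + [2 + i for i in range(len(text.split()))])

batch = tokenizer_image_token(
    ["<image>\ndescribe this", "<image>\nwhat color is the sky today"],
    StubTokenizer(),
    image_token_index=-200,
    return_tensors="pt",
)
print(batch.shape)  # torch.Size([2, 8]); the shorter row is right-padded with pad_token_id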
@@ -204,9 +219,9 @@ class Connector(nn.Module):
         for _ in range(1, mlp_depth):
             modules.append(ACT_TYPE[act_type]())
             modules.append(nn.Linear(config.hidden_size, config.hidden_size))
-
+
         self._connector = nn.Sequential(*modules)
-
+
     def forward(self, x):
         return self._connector(x)
 
@@ -219,9 +234,9 @@ class VisionTower(nn.Module):
         else:
             self._vision_tower = SiglipVisionModel(cfg)
         self._image_processor = SiglipImageProcessor.from_pretrained(cfg.model_name_or_path)
-
+
         self.config = cfg
-
+
     def forward(self, x, **kwargs):
         image_features = self._vision_tower(x, output_hidden_states=True)
         image_features = image_features.hidden_states[kwargs.get('vision_feature_layer', -2)]
@@ -234,11 +249,11 @@ class VisionTower(nn.Module):
             raise ValueError(f"Unexpected select feature: {kwargs.get('vision_feature_select_strategy')}")
 
         return image_features
-
+
     @property
     def vision_tower(self):
         return self._vision_tower
-
+
     @vision_tower.setter
     def vision_tower(self, vision_tower):
         self._vision_tower = vision_tower
@@ -248,7 +263,7 @@ def get_value_from_kwargs(kwargs, name):
         return kwargs.pop(name)
     else:
         return None
-
+
 
 class TinyLlavaPreTrainedModel(PreTrainedModel):
     config_class = TinyLlavaConfig
@@ -284,7 +299,7 @@ class TinyLlavaPreTrainedModel(PreTrainedModel):
 
 class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
     def __init__(self, config: TinyLlavaConfig):
-
+
         super().__init__(config)
 
         self.language_model = PhiForCausalLM(config.text_config)
@@ -292,7 +307,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
         self.connector = Connector(config)
         self.post_init()
 
-
+
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()
 
@@ -322,7 +337,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
-
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -368,7 +383,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict
         )
-
+
     @torch.no_grad()
     def generate(
         self,
@@ -408,7 +423,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
             inputs_embeds=inputs_embeds,
             **kwargs
         )
-
+
     def encode_images(self, images):
         kwargs = {}
         kwargs['vision_feature_layer'] = self.config.vision_feature_layer
@@ -417,9 +432,9 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
         image_features = self.vision_tower(images, **kwargs)
         image_features = self.connector(image_features)
         return image_features
-
-
-
+
+
+
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                       inputs_embeds=None, **kwargs):
         images = kwargs.pop("images", None)
@@ -432,7 +447,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
         if image_sizes is not None:
             inputs['image_sizes'] = image_sizes
         return inputs
-
+
     def prepare_inputs_labels_for_multimodal(
         self, input_ids, position_ids, attention_mask, past_key_values, labels,
         images, image_sizes=None
@@ -441,7 +456,7 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
         if vision_tower is None or images is None or input_ids.shape[1] == 1:
             return input_ids, position_ids, attention_mask, past_key_values, None, labels
 
-
+
         image_features = self.encode_images(images)
 
         # TODO: image start / end is not implemented here to support pretraining.
@@ -565,40 +580,72 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
             position_ids = None
 
         return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
-
+
     def chat(
         self,
-        prompt: str,
-        tokenizer = None,
-        image: str = None,
+        prompts: Union[list, str],
+        tokenizer=None,
+        images: Union[list, str] = None,
         max_new_tokens: int = 512,
-        num_beams = 1,
+        num_beams=1,
         top_p=None,
         temperature=0
     ):
+        """
+        Generate responses for a batch of prompts.
+
+        Args:
+            prompts (list): List of text prompts.
+            tokenizer: Tokenizer object.
+            images (list, optional): List of image file paths corresponding to the prompts. Defaults to None.
+            max_new_tokens (int): Maximum number of new tokens to generate. Defaults to 512.
+            num_beams (int): Number of beams for beam search. Defaults to 1.
+            top_p (float, optional): Nucleus sampling probability. Defaults to None.
+            temperature (float): Sampling temperature. Defaults to 0.
+
+        Returns:
+            list: List of generated outputs.
+            float: Total generation time in seconds.
+        """
+        if isinstance(prompts, list):
+            assert images is None or len(prompts) == len(images), "Mismatch between prompts and images."
+        else:
+            prompts = [prompts]
+            images = [images]
         image_processor = self.vision_tower._image_processor
 
-        if image is not None:
-            prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
-        conv = conv_phi_v0.copy()
-        conv.append_message(conv.roles[0], prompt)
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
-        if image is not None:
-            image = load_image(image)
-            image_tensor = process_images(image, image_processor, self.config).to(self.device)
-
-        input_ids = (
-            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
-            .unsqueeze(0).to(self.device)
-        )
-        # Generate
+        # Prepare inputs
+        input_texts = []
+        image_tensors = None
+
+        for i, prompt in enumerate(prompts):
+            if images and images[i] is not None:
+                prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
+            conv = conv_phi_v0.copy()
+
+            conv.append_message(conv.roles[0], prompt)
+            conv.append_message(conv.roles[1], None)
+            input_texts.append(conv.get_prompt())
+
+        # Tokenize prompts
+        input_ids = tokenizer_image_token(
+            input_texts, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+        ).to(self.device)
+
+        # Process images
+        if images and any(image is not None for image in images):
+            processed_images = [
+                process_images(load_image(image), image_processor, self.config)
+                for image in images if image is not None
+            ]
+            image_tensors = torch.stack(processed_images).to(self.device).squeeze(1)
+
+        # Generate responses
         stime = time.time()
-
         with torch.inference_mode():
            output_ids = self.generate(
                input_ids,
-                images=image_tensor,
+                images=image_tensors,
                do_sample=True if temperature > 0 else False,
                temperature=temperature,
                top_p=top_p,
@@ -606,19 +653,15 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
                 pad_token_id=tokenizer.pad_token_id,
                 max_new_tokens=max_new_tokens,
                 use_cache=True,
-                # stopping_criteria=[stopping_criteria],
             )
-
-        # print('inference over')
         generation_time = time.time() - stime
-        outputs = tokenizer.batch_decode(
-            output_ids, skip_special_tokens=True
-        )[0]
 
-        outputs = outputs.strip()
+        # Decode outputs
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        outputs = [output.strip() for output in outputs]
 
         return outputs, generation_time
-
 
-AutoConfig.register("tinyllava", TinyLlavaConfig)
-AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
+
+AutoConfig.register("tinyllava", TinyLlavaConfig)
+AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
 
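Because the config and model classes are registered with the Auto classes, the new batched chat API can be driven through transformers directly. A minimal usage sketch; the checkpoint path and image files are placeholders, and loading with trust_remote_code is an assumption about how this repo is consumed:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/TinyLLaVA-Phi"  # hypothetical local checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, trust_remote_code=True
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# One prompt per image; chat() applies the phi conversation template, tokenizes
# the whole batch, and returns decoded strings plus the total generation time.
outputs, generation_time = model.chat(
    prompts=["What is shown in this image?", "Describe the weather in the photo."],
    images=["example1.jpg", "example2.jpg"],  # file paths or http(s) URLs
    tokenizer=tokenizer,
    max_new_tokens=128,
)
for text in outputs:
    print(text)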