jonathanjordan21 commited on
Commit
22fcea7
·
verified ·
1 Parent(s): 406e6b9

Update modeling_qwen2_nomic_vision.py

Browse files
Files changed (1) hide show
  1. modeling_qwen2_nomic_vision.py +13 -3
modeling_qwen2_nomic_vision.py CHANGED
@@ -770,6 +770,13 @@ QWEN2_INPUTS_DOCSTRING = r"""
770
  """
771
 
772
 
 
 
 
 
 
 
 
773
  # @add_start_docstrings(
774
  # "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
775
  # QWEN2_START_DOCSTRING,
@@ -824,7 +831,7 @@ class Qwen2NomicVisionModel(Qwen2NomicVisionPreTrainedModel):
824
  return_dict: Optional[bool] = None,
825
  cache_position: Optional[torch.LongTensor] = None,
826
  image = None,
827
- ) -> Union[Tuple, BaseModelOutputWithPast]:
828
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
829
  output_hidden_states = (
830
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -936,12 +943,15 @@ class Qwen2NomicVisionModel(Qwen2NomicVisionPreTrainedModel):
936
  next_cache = next_cache.to_legacy_cache()
937
 
938
  if not return_dict:
 
 
939
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
940
- return BaseModelOutputWithPast(
941
  last_hidden_state=hidden_states,
942
  past_key_values=next_cache,
943
  hidden_states=all_hidden_states,
944
  attentions=all_self_attns,
 
945
  )
946
 
947
  # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
@@ -1196,7 +1206,7 @@ class Qwen2NomicVisionForCausalLM(Qwen2NomicVisionPreTrainedModel, GenerationMix
1196
 
1197
  loss = None
1198
  if labels is not None:
1199
- loss = self.loss_function(logits[:, 1:], labels, self.vocab_size, **loss_kwargs)
1200
 
1201
  if not return_dict:
1202
  output = (logits,) + outputs[1:]
 
770
  """
771
 
772
 
773
@dataclass
class Qwen2NomicVisionOutput(BaseModelOutputWithPast):
    """Output of ``Qwen2NomicVisionModel.forward``.

    Extends ``BaseModelOutputWithPast`` with one extra field carrying the
    processed vision features (``mix``) produced when an ``image`` argument
    is supplied to the model's forward pass.
    """

    # NOTE(review): originally annotated as bare `FloatTensor`, a name not
    # visibly imported in this module; `torch` is imported (the forward
    # signature uses torch.LongTensor), so qualify it and make it Optional
    # to match the `= None` default.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Cache returned by the decoder (legacy tuple or Cache object depending
    # on configuration) — kept loosely typed; original used bare `Optional`.
    past_key_values: Optional[Any] = None
    hidden_states: Optional[Any] = None   # per-layer hidden states when requested
    attentions: Optional[Any] = None      # per-layer attention maps when requested
    # New field added by this commit: vision-pathway features (`mix`).
    processed_image: Optional[torch.FloatTensor] = None
780
  # @add_start_docstrings(
781
  # "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
782
  # QWEN2_START_DOCSTRING,
 
831
  return_dict: Optional[bool] = None,
832
  cache_position: Optional[torch.LongTensor] = None,
833
  image = None,
834
+ ) -> Union[Tuple, Qwen2NomicVisionOutput]:
835
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
836
  output_hidden_states = (
837
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
943
  next_cache = next_cache.to_legacy_cache()
944
 
945
  if not return_dict:
946
+ if image != None:
947
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, mix] if v is not None)
948
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
949
+ return Qwen2NomicVisionOutput(
950
  last_hidden_state=hidden_states,
951
  past_key_values=next_cache,
952
  hidden_states=all_hidden_states,
953
  attentions=all_self_attns,
954
+ processed_image=mix,
955
  )
956
 
957
  # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
 
1206
 
1207
  loss = None
1208
  if labels is not None:
1209
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1210
 
1211
  if not return_dict:
1212
  output = (logits,) + outputs[1:]