SmerkyG committed
Commit 7bafffb · verified · 1 Parent(s): d7233f2

Update modeling_rwkv6qwen2.py

Files changed (1)
  1. modeling_rwkv6qwen2.py +34 -120
modeling_rwkv6qwen2.py CHANGED
@@ -834,126 +834,40 @@ class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )
 
-    # def prepare_inputs_for_generation(
-    #     self,
-    #     input_ids: torch.LongTensor,
-    #     past_key_values: Optional[Cache] = None,
-    #     attention_mask: Optional[torch.LongTensor] = None,
-    #     inputs_embeds: Optional[torch.FloatTensor] = None,
-    #     cache_position: Optional[torch.LongTensor] = None,
-    #     **kwargs,
-    # ):
-    #     """
-    #     Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or
-    #     slicing inputs given the existing cache.
-
-    #     See the forward pass in the model documentation for expected arguments (different models might have different
-    #     requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
-    #     """
-
-    #     # 1. Handle BC:
-    #     model_inputs = {}
-    #     # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
-    #     if self._supports_cache_class:
-    #         model_inputs["cache_position"] = cache_position
-    #     # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
-    #     #   function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
-    #     #   (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
-    #     elif cache_position is None:
-    #         past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
-    #         cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
-
-    #     # 2. Generic cache-dependent input preparation
-    #     # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-    #     # Exception 1: when passing input_embeds, input_ids may be missing entries
-    #     # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-    #     # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
-    #     if past_key_values is not None:
-    #         model_inputs["past_key_values"] = past_key_values
-    #         if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
-    #             input_ids = input_ids[:, -cache_position.shape[0] :]
-    #         elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-    #             input_ids = input_ids[:, cache_position]
-
-    #     # 3. Prepare base model inputs
-    #     input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-    #     # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-    #     if not self.config.is_encoder_decoder:
-    #         if inputs_embeds is not None and cache_position[0] == 0:
-    #             model_inputs[input_ids_key] = None
-    #             model_inputs["inputs_embeds"] = inputs_embeds
-    #         else:
-    #             # `clone` calls in this function ensure a consistent stride. See #32227
-    #             model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
-    #             model_inputs["inputs_embeds"] = None
-    #     else:
-    #         model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
-
-    #     # 4. Create missing `position_ids` on the fly
-    #     if (attention_mask is not None and kwargs.get("position_ids") is None and "position_ids" in set(inspect.signature(self.forward).parameters.keys())):
-    #         position_ids = attention_mask.long().cumsum(-1) - 1
-    #         position_ids.masked_fill_(attention_mask == 0, 1)
-    #         kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)
-
-    #     # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
-    #     for model_input_name in ["position_ids", "token_type_ids"]:
-    #         model_input = kwargs.get(model_input_name)
-    #         if model_input is not None:
-    #             if past_key_values:
-    #                 model_input = model_input[:, -input_ids.shape[1] :]
-    #             model_input = model_input.clone(memory_format=torch.contiguous_format)
-    #             model_inputs[model_input_name] = model_input
-
-    #     # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
-    #     if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-    #         if model_inputs["inputs_embeds"] is not None:
-    #             batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
-    #             device = model_inputs["inputs_embeds"].device
-    #         else:
-    #             batch_size, sequence_length = model_inputs[input_ids_key].shape
-    #             device = model_inputs[input_ids_key].device
-
-    #         # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
-    #         # the 4D causal mask exists, it should be present in the base model (XXXModel class).
-    #         base_model = getattr(self, self.base_model_prefix, None)
-    #         if base_model is None:
-    #             causal_mask_creation_function = getattr(
-    #                 self, "_prepare_4d_causal_attention_mask_with_cache_position", None
-    #             )
-    #         else:
-    #             causal_mask_creation_function = getattr(
-    #                 base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
-    #             )
-    #         if causal_mask_creation_function is None:
-    #             logger.warning_once(
-    #                 f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
-    #                 "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
-    #                 "writing code, see Llama for an example implementation. If you're a user, please report this "
-    #                 "issue on GitHub."
-    #             )
-    #         else:
-    #             attention_mask = causal_mask_creation_function(
-    #                 attention_mask,
-    #                 sequence_length=sequence_length,
-    #                 target_length=past_key_values.get_max_cache_shape(),
-    #                 dtype=self.dtype,
-    #                 device=device,
-    #                 cache_position=cache_position,
-    #                 batch_size=batch_size,
-    #                 config=self.config,
-    #                 past_key_values=past_key_values,
-    #             )
-    #     if attention_mask is not None:
-    #         model_inputs["attention_mask"] = attention_mask
-
-    #     # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
-    #     for key, value in kwargs.items():
-    #         if key not in model_inputs:
-    #             model_inputs[key] = value
-
-    #     # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
-    #     model_inputs.pop("labels", None)
-    #     return model_inputs
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ):
+        # only keep the last token of `input_ids` if `past_key_values` is not empty.
+        if past_key_values is not None and len(past_key_values) > 0:
+            input_ids = input_ids[:, -1:]
+
+        model_inputs = {
+            'past_key_values': past_key_values,
+            'attention_mask': attention_mask,
+            'cache_position': cache_position,
+        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs['inputs_embeds'] = inputs_embeds
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard.
+            # Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs['input_ids'] = input_ids.contiguous()
+
+        model_inputs.update(**kwargs)
+
+        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+        model_inputs.pop("labels", None)
+
+        return model_inputs
 
     @add_start_docstrings(
         """