Update modeling_rwkv6qwen2.py

modeling_rwkv6qwen2.py  CHANGED  (+34 -120)
@@ -834,126 +834,40 @@ class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    # if past_key_values is not None:
-    #     model_inputs["past_key_values"] = past_key_values
-    #     if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
-    #         input_ids = input_ids[:, -cache_position.shape[0] :]
-    #     elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-    #         input_ids = input_ids[:, cache_position]
-
-    # # 3. Prepare base model inputs
-    # input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-    # # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-    # if not self.config.is_encoder_decoder:
-    #     if inputs_embeds is not None and cache_position[0] == 0:
-    #         model_inputs[input_ids_key] = None
-    #         model_inputs["inputs_embeds"] = inputs_embeds
-    #     else:
-    #         # `clone` calls in this function ensure a consistent stride. See #32227
-    #         model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
-    #         model_inputs["inputs_embeds"] = None
-    # else:
-    #     model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
-
-    # # 4. Create missing `position_ids` on the fly
-    # if (attention_mask is not None and kwargs.get("position_ids") is None and "position_ids" in set(inspect.signature(self.forward).parameters.keys())):
-    #     position_ids = attention_mask.long().cumsum(-1) - 1
-    #     position_ids.masked_fill_(attention_mask == 0, 1)
-    #     kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)
-
-    # # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
-    # for model_input_name in ["position_ids", "token_type_ids"]:
-    #     model_input = kwargs.get(model_input_name)
-    #     if model_input is not None:
-    #         if past_key_values:
-    #             model_input = model_input[:, -input_ids.shape[1] :]
-    #             model_input = model_input.clone(memory_format=torch.contiguous_format)
-    #         model_inputs[model_input_name] = model_input
-
-    # # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
-    # if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-    #     if model_inputs["inputs_embeds"] is not None:
-    #         batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
-    #         device = model_inputs["inputs_embeds"].device
-    #     else:
-    #         batch_size, sequence_length = model_inputs[input_ids_key].shape
-    #         device = model_inputs[input_ids_key].device
-
-    #     # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
-    #     # the 4D causal mask exists, it should be present in the base model (XXXModel class).
-    #     base_model = getattr(self, self.base_model_prefix, None)
-    #     if base_model is None:
-    #         causal_mask_creation_function = getattr(
-    #             self, "_prepare_4d_causal_attention_mask_with_cache_position", None
-    #         )
-    #     else:
-    #         causal_mask_creation_function = getattr(
-    #             base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
-    #         )
-    #     if causal_mask_creation_function is None:
-    #         logger.warning_once(
-    #             f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
-    #             "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
-    #             "writing code, see Llama for an example implementation. If you're a user, please report this "
-    #             "issue on GitHub."
-    #         )
-    #     else:
-    #         attention_mask = causal_mask_creation_function(
-    #             attention_mask,
-    #             sequence_length=sequence_length,
-    #             target_length=past_key_values.get_max_cache_shape(),
-    #             dtype=self.dtype,
-    #             device=device,
-    #             cache_position=cache_position,
-    #             batch_size=batch_size,
-    #             config=self.config,
-    #             past_key_values=past_key_values,
-    #         )
-    # if attention_mask is not None:
-    #     model_inputs["attention_mask"] = attention_mask
-
-    # # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
-    # for key, value in kwargs.items():
-    #     if key not in model_inputs:
-    #         model_inputs[key] = value
-
-    # # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
-    # model_inputs.pop("labels", None)
-    # return model_inputs
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ):
+        # only last token for `inputs_ids` if the `past_key_values` is not empty.
+        if past_key_values is not None and len(past_key_values) > 0:
+            input_ids = input_ids[:, -1:]
+
+        model_inputs = {
+            'past_key_values': past_key_values,
+            'attention_mask': attention_mask,
+            'cache_position': cache_position,
+        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs['inputs_embeds'] = inputs_embeds
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard.
+            # Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs['input_ids'] = input_ids.contiguous()
+
+        model_inputs.update(**kwargs)
+
+        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+        model_inputs.pop("labels", None)
+
+        return model_inputs

 @add_start_docstrings(
     """
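For reference, `prepare_inputs_for_generation` is the hook that `GenerationMixin.generate` calls before every forward pass, so the rewritten method above forwards the full prompt only on the first step and a single token once the cache is non-empty. Below is a minimal, hypothetical sketch (not part of the commit) of a manual greedy-decoding loop that exercises the new method; the checkpoint path is a placeholder, and it assumes the model returns an updated `past_key_values` in its output when `use_cache=True`.

# Hypothetical usage sketch; "path/to/rwkv6qwen2-checkpoint" is a placeholder path.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/rwkv6qwen2-checkpoint")
model = AutoModelForCausalLM.from_pretrained("path/to/rwkv6qwen2-checkpoint", trust_remote_code=True)

input_ids = tok("Hello", return_tensors="pt").input_ids
past_key_values = None
for _ in range(8):
    # First step: the cache is empty, so the whole prompt is forwarded. Later steps:
    # the cache is non-empty, so prepare_inputs_for_generation trims input_ids to the last token.
    model_inputs = model.prepare_inputs_for_generation(
        input_ids, past_key_values=past_key_values, use_cache=True
    )
    out = model(**model_inputs)
    past_key_values = out.past_key_values
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tok.decode(input_ids[0], skip_special_tokens=True))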