d-Matrix committed on
Commit
0c0c1c7
·
verified ·
1 Parent(s): d019b6f

Update modeling_gemma.py

Files changed (1):
  1. modeling_gemma.py (+418, −163)
modeling_gemma.py CHANGED
@@ -14,6 +14,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  """ PyTorch Gemma model."""
  import math
  import warnings
  from typing import List, Optional, Tuple, Union
@@ -24,15 +25,20 @@ import torch.utils.checkpoint
24
  from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
 
27
- from ...activations import ACT2FN
28
- from ...cache_utils import Cache, DynamicCache, StaticCache
29
- from ...modeling_attn_mask_utils import (
 
30
  _prepare_4d_causal_attention_mask,
31
  )
32
- from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
33
- from ...modeling_utils import PreTrainedModel
34
- from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
35
- from ...utils import (
 
 
 
 
36
  add_start_docstrings,
37
  add_start_docstrings_to_model_forward,
38
  is_flash_attn_2_available,
@@ -40,7 +46,7 @@ from ...utils import (
40
  logging,
41
  replace_return_docstrings,
42
  )
43
- from ...utils.import_utils import is_torch_fx_available
44
  from .configuration_gemma import GemmaConfig
45
 
46
 
@@ -48,7 +54,7 @@ if is_flash_attn_2_available():
48
  from flash_attn import flash_attn_func, flash_attn_varlen_func
49
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
50
 
51
-
52
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
53
  # It means that the function will not be traced through and simply appear as a node in the graph.
54
  if is_torch_fx_available():
@@ -67,7 +73,9 @@ def _get_unpad_data(attention_mask):
67
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
68
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
69
  max_seqlen_in_batch = seqlens_in_batch.max().item()
70
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
71
  return (
72
  indices,
73
  cu_seqlens,
@@ -75,20 +83,6 @@ def _get_unpad_data(attention_mask):
75
  )
76
 
77
 
78
- class GemmaRMSNorm(nn.Module):
-     def __init__(self, dim: int, eps: float = 1e-6):
-         super().__init__()
-         self.eps = eps
-         self.weight = nn.Parameter(torch.zeros(dim))
-
-     def _norm(self, x):
-         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-     def forward(self, x):
-         output = self._norm(x.float()).type_as(x)
-         return output.to(self.weight.device) * (1 + self.weight)
-
-
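
For reference, the GemmaRMSNorm removed here (this file re-imports it from transformers.models.gemma.modeling_gemma later in the diff) computes the following normalization, restated as math rather than code:

    \mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^{2} + \varepsilon}} \cdot (1 + w),
    \qquad \varepsilon = 10^{-6}, \quad w \in \mathbb{R}^{d} \ \text{(initialized to zero)}

That is, the input is divided by its root-mean-square (computed in float32, per the removed forward) and scaled by 1 + weight, so a zero-initialized weight leaves activations at unit scale.
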
92
  ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
93
 
94
 
@@ -107,16 +101,30 @@ class GemmaRotaryEmbedding(nn.Module):
107
  # x: [bs, num_attention_heads, seq_len, head_size]
108
  if self.inv_freq is None:
109
  self.inv_freq = 1.0 / (
110
- self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
 
 
 
 
 
 
111
  )
112
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
113
- position_ids_expanded = position_ids[:, None, :].float().to(x.device)
 
 
114
  # Force float32 since bfloat16 loses precision on long contexts
115
  # See https://github.com/huggingface/transformers/pull/29285
116
  device_type = x.device.type
117
- device_type = device_type if isinstance(device_type, str) else "cpu"
 
 
 
 
118
  with torch.autocast(device_type=device_type, enabled=False):
119
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 
 
120
  emb = torch.cat((freqs, freqs), dim=-1)
121
  cos = emb.cos()
122
  sin = emb.sin()
@@ -159,7 +167,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
159
  return q_embed, k_embed
160
 
161
 
162
- # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemma
163
  class GemmaMLP(nn.Module):
164
  def __init__(self, config):
165
  super().__init__()
@@ -169,7 +176,18 @@ class GemmaMLP(nn.Module):
169
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
170
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
171
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
172
- self.act_fn = ACT2FN[config.hidden_act]
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  def forward(self, x):
175
  return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
@@ -184,7 +202,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
184
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
185
  if n_rep == 1:
186
  return hidden_states
187
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
 
 
188
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
189
 
190
 
@@ -219,10 +239,22 @@ class GemmaAttention(nn.Module):
219
  f" and `num_heads`: {self.num_heads})."
220
  )
221
 
222
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
223
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
224
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
225
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
 
 
 
 
 
 
 
 
 
 
 
226
  self.rotary_emb = GemmaRotaryEmbedding(
227
  self.head_dim,
228
  max_position_embeddings=self.max_position_embeddings,
@@ -246,34 +278,47 @@ class GemmaAttention(nn.Module):
246
  key_states = self.k_proj(hidden_states)
247
  value_states = self.v_proj(hidden_states)
248
 
249
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
250
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
251
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
252
 
253
  past_key_value = getattr(self, "past_key_value", past_key_value)
254
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
255
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
 
256
 
257
  if past_key_value is not None:
258
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
259
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
260
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
261
 
262
  key_states = repeat_kv(key_states, self.num_key_value_groups)
263
  value_states = repeat_kv(value_states, self.num_key_value_groups)
264
 
265
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
 
266
 
267
  if attention_mask is not None: # no matter the length, we just slice it
268
- if cache_position is not None:
269
- causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
270
- else:
271
- causal_mask = attention_mask
272
  attn_weights = attn_weights + causal_mask
273
 
274
  # upcast attention to fp32
275
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
276
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
 
 
 
277
  attn_output = torch.matmul(attn_weights, value_states)
278
 
279
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -332,19 +377,29 @@ class GemmaFlashAttention2(GemmaAttention):
332
  # Flash attention requires the input to have the shape
333
  # batch_size x seq_length x head_dim x hidden_dim
334
  # therefore we just need to keep the original shape
335
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
336
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
337
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
338
 
339
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
340
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
 
341
 
342
  past_key_value = getattr(self, "past_key_value", past_key_value)
343
 
344
  if past_key_value is not None:
345
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
346
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
347
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
348
 
349
  # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
350
  # to be able to avoid many of these transpose/reshape/view.
@@ -381,7 +436,12 @@ class GemmaFlashAttention2(GemmaAttention):
381
  value_states = value_states.to(target_dtype)
382
 
383
  attn_output = self._flash_attention_forward(
384
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
 
 
 
 
 
385
  )
386
 
387
  attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -393,7 +453,14 @@ class GemmaFlashAttention2(GemmaAttention):
393
  return attn_output, attn_weights, past_key_value
394
 
395
  def _flash_attention_forward(
396
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
 
 
 
 
 
 
 
397
  ):
398
  """
399
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -409,7 +476,7 @@ class GemmaFlashAttention2(GemmaAttention):
409
  attention_mask (`torch.Tensor`):
410
  The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
411
  position of padding tokens and 1 for the position of non-padding tokens.
412
- dropout (`int`, *optional*):
413
  Attention dropout
414
  softmax_scale (`float`, *optional*):
415
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
@@ -423,7 +490,14 @@ class GemmaFlashAttention2(GemmaAttention):
423
  # Contains at least one padding token in the sequence
424
  if attention_mask is not None:
425
  batch_size = query_states.shape[0]
426
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
 
 
 
 
 
 
 
427
  query_states, key_states, value_states, attention_mask, query_length
428
  )
429
 
@@ -443,27 +517,39 @@ class GemmaFlashAttention2(GemmaAttention):
443
  causal=causal,
444
  )
445
 
446
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
 
447
  else:
448
  attn_output = flash_attn_func(
449
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
 
 
 
 
 
450
  )
451
 
452
  return attn_output
453
 
454
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
 
 
455
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
456
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
457
 
458
  key_layer = index_first_axis(
459
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
460
  )
461
  value_layer = index_first_axis(
462
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
463
  )
464
  if query_length == kv_seq_len:
465
  query_layer = index_first_axis(
466
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
 
467
  )
468
  cu_seqlens_q = cu_seqlens_k
469
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -478,7 +564,9 @@ class GemmaFlashAttention2(GemmaAttention):
478
  else:
479
  # The -q_len: slice assumes left padding.
480
  attention_mask = attention_mask[:, -query_length:]
481
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
 
 
482
 
483
  return (
484
  query_layer,
@@ -531,40 +619,53 @@ class GemmaSdpaAttention(GemmaAttention):
531
  key_states = self.k_proj(hidden_states)
532
  value_states = self.v_proj(hidden_states)
533
 
534
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
535
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
536
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
537
 
538
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
539
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
 
540
 
541
  past_key_value = getattr(self, "past_key_value", past_key_value)
542
 
543
  if past_key_value is not None:
544
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
545
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
546
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
547
 
548
  key_states = repeat_kv(key_states, self.num_key_value_groups)
549
  value_states = repeat_kv(value_states, self.num_key_value_groups)
550
 
551
  causal_mask = attention_mask
552
- if attention_mask is not None and cache_position is not None:
553
- causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
554
 
555
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
556
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
557
- if query_states.device.type == "cuda" and causal_mask is not None:
558
  query_states = query_states.contiguous()
559
  key_states = key_states.contiguous()
560
  value_states = value_states.contiguous()
561
 
 
 
562
  attn_output = torch.nn.functional.scaled_dot_product_attention(
563
  query_states,
564
  key_states,
565
  value_states,
566
- attn_mask=causal_mask.to(query_states.device),
567
  dropout_p=self.attention_dropout if self.training else 0.0,
 
568
  )
569
 
570
  attn_output = attn_output.transpose(1, 2).contiguous()
@@ -588,11 +689,15 @@ class GemmaDecoderLayer(nn.Module):
588
  super().__init__()
589
  self.hidden_size = config.hidden_size
590
 
591
- self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
 
592
 
593
  self.mlp = GemmaMLP(config)
594
  self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
595
- self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
596
 
597
  def forward(
598
  self,
@@ -604,7 +709,9 @@ class GemmaDecoderLayer(nn.Module):
604
  use_cache: Optional[bool] = False,
605
  cache_position: Optional[torch.LongTensor] = None,
606
  **kwargs,
607
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
608
  """
609
  Args:
610
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -639,13 +746,14 @@ class GemmaDecoderLayer(nn.Module):
639
  cache_position=cache_position,
640
  **kwargs,
641
  )
642
- hidden_states = residual.to(hidden_states.device) + hidden_states
643
 
644
  # Fully Connected
645
  residual = hidden_states
646
  hidden_states = self.post_attention_layernorm(hidden_states)
647
  hidden_states = self.mlp(hidden_states)
648
  hidden_states = residual + hidden_states
 
649
  outputs = (hidden_states,)
650
 
651
  if output_attentions:
@@ -700,21 +808,26 @@ class GemmaPreTrainedModel(PreTrainedModel):
700
  if module.padding_idx is not None:
701
  module.weight.data[module.padding_idx].zero_()
702
 
703
- def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
704
- if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
 
 
 
 
 
705
  raise ValueError(
706
  "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
707
  "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
708
  )
709
 
710
- if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
711
- causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
712
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
713
-
714
  for layer in self.model.layers:
715
  weights = layer.self_attn.o_proj.weight
716
  layer.self_attn.past_key_value = cache_cls(
717
- self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
 
 
 
 
718
  )
719
 
720
  def _reset_cache(self):
@@ -789,6 +902,10 @@ GEMMA_INPUTS_DOCSTRING = r"""
789
  more detail.
790
  return_dict (`bool`, *optional*):
791
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 
 
 
 
792
  """
793
 
794
 
@@ -810,19 +927,18 @@ class GemmaModel(GemmaPreTrainedModel):
810
  self.padding_idx = config.pad_token_id
811
  self.vocab_size = config.vocab_size
812
 
813
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
 
814
  self.layers = nn.ModuleList(
815
- [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 
 
 
816
  )
817
  self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
818
  self.gradient_checkpointing = False
819
 
820
- # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
821
- # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
822
- causal_mask = torch.full(
823
- (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
824
- )
825
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
826
  # Initialize weights and apply final processing
827
  self.post_init()
828
 
@@ -847,12 +963,20 @@ class GemmaModel(GemmaPreTrainedModel):
847
  return_dict: Optional[bool] = None,
848
  cache_position: Optional[torch.LongTensor] = None,
849
  ) -> Union[Tuple, BaseModelOutputWithPast]:
850
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
851
  output_hidden_states = (
852
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
853
  )
854
  use_cache = use_cache if use_cache is not None else self.config.use_cache
855
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
856
 
857
  if (input_ids is None) ^ (inputs_embeds is not None):
858
  raise ValueError(
@@ -876,19 +1000,28 @@ class GemmaModel(GemmaPreTrainedModel):
876
 
877
  if cache_position is None:
878
  cache_position = torch.arange(
879
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
 
880
  )
881
 
882
  if position_ids is None:
883
  position_ids = cache_position.unsqueeze(0)
884
 
885
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
 
 
886
 
887
  # embed positions
888
  hidden_states = inputs_embeds
889
 
890
  # normalized
891
- hidden_states = hidden_states * (self.config.hidden_size**0.5)
 
 
 
 
 
892
 
893
  # decoder layers
894
  all_hidden_states = () if output_hidden_states else None
@@ -938,10 +1071,16 @@ class GemmaModel(GemmaPreTrainedModel):
938
  next_cache = None
939
  if use_cache:
940
  next_cache = (
941
- next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
 
 
942
  )
943
  if not return_dict:
944
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
945
  return BaseModelOutputWithPast(
946
  last_hidden_state=hidden_states,
947
  past_key_values=next_cache,
@@ -949,43 +1088,102 @@ class GemmaModel(GemmaPreTrainedModel):
949
  attentions=all_self_attns,
950
  )
951
 
952
- def _update_causal_mask(self, attention_mask, input_tensor):
 
 
 
 
 
 
 
 
 
 
 
953
  if self.config._attn_implementation == "flash_attention_2":
954
  if attention_mask is not None and 0.0 in attention_mask:
955
  return attention_mask
956
  return None
957
 
958
- batch_size, seq_length = input_tensor.shape[:2]
959
- dtype = input_tensor.dtype
960
- device = input_tensor.device
961
-
962
- # support going beyond cached `max_position_embedding`
963
- if seq_length > self.causal_mask.shape[-1]:
964
- causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
965
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
966
 
967
- # We use the current dtype to avoid any overflows
968
  min_dtype = torch.finfo(dtype).min
969
- causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
970
-
971
- causal_mask = causal_mask.to(dtype=dtype, device=device)
972
- if attention_mask is not None and attention_mask.dim() == 2:
973
- mask_length = attention_mask.shape[-1]
974
- padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
975
- causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
976
-
977
- if self.config._attn_implementation == "sdpa" and attention_mask is not None:
978
- # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
979
- is_tracing = (
980
- torch.jit.is_tracing()
981
- or isinstance(input_tensor, torch.fx.Proxy)
982
- or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
983
  )
984
- if not is_tracing and torch.any(attention_mask != 1):
985
- # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
986
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
987
- # Details: https://github.com/pytorch/pytorch/issues/110213
988
- causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
989
 
990
  return causal_mask
991
 
@@ -1023,7 +1221,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1023
 
1024
  # Ignore copy
1025
  @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
1026
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
1027
  def forward(
1028
  self,
1029
  input_ids: torch.LongTensor = None,
@@ -1063,11 +1263,19 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1063
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1064
  "What is your favorite condiment?"
1065
  ```"""
1066
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
1067
  output_hidden_states = (
1068
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
1069
  )
1070
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1071
 
1072
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1073
  outputs = self.model(
@@ -1112,14 +1320,44 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1112
  )
1113
 
1114
  def prepare_inputs_for_generation(
1115
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 
 
 
 
 
 
1116
  ):
 
 
 
 
 
 
 
 
 
1117
  past_length = 0
1118
  if past_key_values is not None:
1119
  if isinstance(past_key_values, Cache):
1120
- cache_length = past_key_values.get_seq_length()
1121
- past_length = past_key_values.seen_tokens
1122
- max_cache_length = past_key_values.get_max_length()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1123
  else:
1124
  cache_length = past_length = past_key_values[0][0].shape[2]
1125
  max_cache_length = None
@@ -1128,7 +1366,10 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1128
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1129
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1130
  # input)
1131
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
 
 
 
1132
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1133
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1134
  # input_ids based on the past_length.
@@ -1152,20 +1393,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1152
  if past_key_values:
1153
  position_ids = position_ids[:, -input_ids.shape[1] :]
1154
 
1155
- if self.generation_config.cache_implementation == "static":
1156
- # generation with static cache
1157
- cache_position = kwargs.get("cache_position", None)
1158
- if cache_position is None:
1159
- past_length = 0
1160
- else:
1161
- past_length = cache_position[-1] + 1
1162
- input_ids = input_ids[:, past_length:]
1163
- position_ids = position_ids[:, past_length:]
1164
-
1165
- # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
1166
- # same goes for position ids. Could also help with continued generation.
1167
- cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
1168
-
1169
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1170
  if inputs_embeds is not None and past_key_values is None:
1171
  model_inputs = {"inputs_embeds": inputs_embeds}
@@ -1175,9 +1402,22 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1175
  # TODO: use `next_tokens` directly instead.
1176
  model_inputs = {"input_ids": input_ids.contiguous()}
1177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1178
  model_inputs.update(
1179
  {
1180
- "position_ids": position_ids.contiguous(),
1181
  "cache_position": cache_position,
1182
  "past_key_values": past_key_values,
1183
  "use_cache": kwargs.get("use_cache"),
@@ -1191,7 +1431,10 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1191
  reordered_past = ()
1192
  for layer_past in past_key_values:
1193
  reordered_past += (
1194
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
 
 
 
1195
  )
1196
  return reordered_past
1197
 
@@ -1248,7 +1491,9 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
1248
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1249
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1250
  """
1251
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1252
 
1253
  transformer_outputs = self.model(
1254
  input_ids,
@@ -1270,19 +1515,25 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
1270
  batch_size = inputs_embeds.shape[0]
1271
 
1272
  if self.config.pad_token_id is None and batch_size != 1:
1273
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
 
 
1274
  if self.config.pad_token_id is None:
1275
  sequence_lengths = -1
1276
  else:
1277
  if input_ids is not None:
1278
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1279
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
 
 
1280
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1281
  sequence_lengths = sequence_lengths.to(logits.device)
1282
  else:
1283
  sequence_lengths = -1
1284
 
1285
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
 
1286
 
1287
  loss = None
1288
  if labels is not None:
@@ -1290,7 +1541,9 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
1290
  if self.config.problem_type is None:
1291
  if self.num_labels == 1:
1292
  self.config.problem_type = "regression"
1293
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
 
 
1294
  self.config.problem_type = "single_label_classification"
1295
  else:
1296
  self.config.problem_type = "multi_label_classification"
@@ -1303,7 +1556,9 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
1303
  loss = loss_fct(pooled_logits, labels)
1304
  elif self.config.problem_type == "single_label_classification":
1305
  loss_fct = CrossEntropyLoss()
1306
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1307
  elif self.config.problem_type == "multi_label_classification":
1308
  loss_fct = BCEWithLogitsLoss()
1309
  loss = loss_fct(pooled_logits, labels)
 
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
  """ PyTorch Gemma model."""
17
+
18
  import math
19
  import warnings
20
  from typing import List, Optional, Tuple, Union
 
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
 
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
+ from transformers.modeling_attn_mask_utils import (
31
+ AttentionMaskConverter,
32
  _prepare_4d_causal_attention_mask,
33
  )
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPast,
36
+ CausalLMOutputWithPast,
37
+ SequenceClassifierOutputWithPast,
38
+ )
39
+ from transformers.modeling_utils import PreTrainedModel
40
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
41
+ from transformers.utils import (
42
  add_start_docstrings,
43
  add_start_docstrings_to_model_forward,
44
  is_flash_attn_2_available,
 
46
  logging,
47
  replace_return_docstrings,
48
  )
49
+ from transformers.utils.import_utils import is_torch_fx_available
50
  from .configuration_gemma import GemmaConfig
51
 
52
 
 
54
  from flash_attn import flash_attn_func, flash_attn_varlen_func
55
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
56
 
57
+ from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
58
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
59
  # It means that the function will not be traced through and simply appear as a node in the graph.
60
  if is_torch_fx_available():
 
73
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
74
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
75
  max_seqlen_in_batch = seqlens_in_batch.max().item()
76
+ cu_seqlens = F.pad(
77
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
78
+ )
79
  return (
80
  indices,
81
  cu_seqlens,
 
83
  )
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
87
 
88
 
 
101
  # x: [bs, num_attention_heads, seq_len, head_size]
102
  if self.inv_freq is None:
103
  self.inv_freq = 1.0 / (
104
+ self.base
105
+ ** (
106
+ torch.arange(
107
+ 0, self.dim, 2, dtype=torch.int64, device=x.device
108
+ ).float()
109
+ / self.dim
110
+ )
111
  )
112
+ inv_freq_expanded = (
113
+ self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
114
+ )
115
+ position_ids_expanded = position_ids[:, None, :].float()
116
  # Force float32 since bfloat16 loses precision on long contexts
117
  # See https://github.com/huggingface/transformers/pull/29285
118
  device_type = x.device.type
119
+ device_type = (
120
+ device_type
121
+ if isinstance(device_type, str) and device_type != "mps"
122
+ else "cpu"
123
+ )
124
  with torch.autocast(device_type=device_type, enabled=False):
125
+ freqs = (
126
+ inv_freq_expanded.float() @ position_ids_expanded.float()
127
+ ).transpose(1, 2)
128
  emb = torch.cat((freqs, freqs), dim=-1)
129
  cos = emb.cos()
130
  sin = emb.sin()
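
Restated as math (not part of the diff), the rotary-embedding block above builds per-pair inverse frequencies and position angles, then takes cosines and sines of the duplicated frequency table:

    \theta_i = \mathrm{base}^{-2i/d}, \quad i = 0, \dots, d/2 - 1, \qquad
    \mathrm{freqs}_{m,i} = m\,\theta_i, \qquad
    \mathrm{emb}_m = [\mathrm{freqs}_m \,\|\, \mathrm{freqs}_m], \quad
    \cos_m = \cos(\mathrm{emb}_m), \ \ \sin_m = \sin(\mathrm{emb}_m)

where d is the head dimension and m the position index; the matmul and trig are forced to float32 under the disabled-autocast block because, as the comment notes, bfloat16 loses precision on long contexts.
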
 
167
  return q_embed, k_embed
168
 
169
 
 
170
  class GemmaMLP(nn.Module):
171
  def __init__(self, config):
172
  super().__init__()
 
176
      self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
      self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
      self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+     if config.hidden_activation is None:
+         logger.warning_once(
+             "Gemma's activation function should be approximate GeLU and not exact GeLU.\n"
+             "Changing the activation function to `gelu_pytorch_tanh`."
+             f"if you want to use the legacy `{config.hidden_act}`, "
+             f"edit the `model.config` to set `hidden_activation={config.hidden_act}` "
+             " instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
+         )
+         hidden_activation = "gelu_pytorch_tanh"
+     else:
+         hidden_activation = config.hidden_activation
+     self.act_fn = ACT2FN[hidden_activation]

  def forward(self, x):
      return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
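
A quick, standalone illustration of why the warning above exists (not part of the diff; values are illustrative): `gelu_pytorch_tanh` is the tanh approximation of GELU, which differs slightly from the erf-based exact GELU that a plain `gelu` setting would select, and Gemma was trained with the approximate form.

import torch
import torch.nn as nn

x = torch.linspace(-4.0, 4.0, steps=1001)
exact = nn.GELU()(x)                      # erf-based "exact" GELU
approx = nn.GELU(approximate="tanh")(x)   # tanh approximation, what `gelu_pytorch_tanh` maps to

# The curves differ by a small (well below 1e-2) but nonzero amount; applied in the MLP of
# every layer, this is enough to visibly shift logits versus the checkpoint's training setup.
print((exact - approx).abs().max())
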
 
202
      batch, num_key_value_heads, slen, head_dim = hidden_states.shape
      if n_rep == 1:
          return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(
+         batch, num_key_value_heads, n_rep, slen, head_dim
+     )
      return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
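
A standalone sketch of what repeat_kv buys for grouped-query attention (illustrative shapes, not part of the diff): a small number of key/value heads is broadcast to the full set of query heads via expand + reshape, without materializing per-head copies until the final reshape.

import torch

kv = torch.randn(1, 2, 5, 16)   # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 4                       # query heads per KV head

expanded = kv[:, :, None, :, :].expand(1, 2, n_rep, 5, 16).reshape(1, 2 * n_rep, 5, 16)
print(expanded.shape)           # torch.Size([1, 8, 5, 16])
# heads 0..3 of the result are copies of KV head 0, heads 4..7 of KV head 1
assert torch.equal(expanded[:, 0], expanded[:, n_rep - 1])
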
209
 
210
 
 
239
  f" and `num_heads`: {self.num_heads})."
240
  )
241
 
242
+ self.q_proj = nn.Linear(
243
+ self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
244
+ )
245
+ self.k_proj = nn.Linear(
246
+ self.hidden_size,
247
+ self.num_key_value_heads * self.head_dim,
248
+ bias=config.attention_bias,
249
+ )
250
+ self.v_proj = nn.Linear(
251
+ self.hidden_size,
252
+ self.num_key_value_heads * self.head_dim,
253
+ bias=config.attention_bias,
254
+ )
255
+ self.o_proj = nn.Linear(
256
+ self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
257
+ )
258
  self.rotary_emb = GemmaRotaryEmbedding(
259
  self.head_dim,
260
  max_position_embeddings=self.max_position_embeddings,
 
278
  key_states = self.k_proj(hidden_states)
279
  value_states = self.v_proj(hidden_states)
280
 
281
+ query_states = query_states.view(
282
+ bsz, q_len, self.num_heads, self.head_dim
283
+ ).transpose(1, 2)
284
+ key_states = key_states.view(
285
+ bsz, q_len, self.num_key_value_heads, self.head_dim
286
+ ).transpose(1, 2)
287
+ value_states = value_states.view(
288
+ bsz, q_len, self.num_key_value_heads, self.head_dim
289
+ ).transpose(1, 2)
290
 
291
  past_key_value = getattr(self, "past_key_value", past_key_value)
292
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
293
+ query_states, key_states = apply_rotary_pos_emb(
294
+ query_states, key_states, cos, sin, None
295
+ )
296
 
297
  if past_key_value is not None:
298
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
299
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
300
+ key_states, value_states = past_key_value.update(
301
+ key_states, value_states, self.layer_idx, cache_kwargs
302
+ )
303
 
304
  key_states = repeat_kv(key_states, self.num_key_value_groups)
305
  value_states = repeat_kv(value_states, self.num_key_value_groups)
306
 
307
+ attn_weights = torch.matmul(
308
+ query_states, key_states.transpose(2, 3)
309
+ ) / math.sqrt(self.head_dim)
310
 
311
  if attention_mask is not None: # no matter the length, we just slice it
312
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
 
 
 
313
  attn_weights = attn_weights + causal_mask
314
 
315
  # upcast attention to fp32
316
+ attn_weights = nn.functional.softmax(
317
+ attn_weights, dim=-1, dtype=torch.float32
318
+ ).to(query_states.dtype)
319
+ attn_weights = nn.functional.dropout(
320
+ attn_weights, p=self.attention_dropout, training=self.training
321
+ )
322
  attn_output = torch.matmul(attn_weights, value_states)
323
 
324
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
377
  # Flash attention requires the input to have the shape
378
  # batch_size x seq_length x head_dim x hidden_dim
379
  # therefore we just need to keep the original shape
380
+ query_states = query_states.view(
381
+ bsz, q_len, self.num_heads, self.head_dim
382
+ ).transpose(1, 2)
383
+ key_states = key_states.view(
384
+ bsz, q_len, self.num_key_value_heads, self.head_dim
385
+ ).transpose(1, 2)
386
+ value_states = value_states.view(
387
+ bsz, q_len, self.num_key_value_heads, self.head_dim
388
+ ).transpose(1, 2)
389
 
390
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
391
+ query_states, key_states = apply_rotary_pos_emb(
392
+ query_states, key_states, cos, sin, None
393
+ )
394
 
395
  past_key_value = getattr(self, "past_key_value", past_key_value)
396
 
397
  if past_key_value is not None:
398
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
399
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
400
+ key_states, value_states = past_key_value.update(
401
+ key_states, value_states, self.layer_idx, cache_kwargs
402
+ )
403
 
404
  # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
405
  # to be able to avoid many of these transpose/reshape/view.
 
436
  value_states = value_states.to(target_dtype)
437
 
438
  attn_output = self._flash_attention_forward(
439
+ query_states,
440
+ key_states,
441
+ value_states,
442
+ attention_mask,
443
+ q_len,
444
+ dropout=dropout_rate,
445
  )
446
 
447
  attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
 
453
  return attn_output, attn_weights, past_key_value
454
 
455
  def _flash_attention_forward(
456
+ self,
457
+ query_states,
458
+ key_states,
459
+ value_states,
460
+ attention_mask,
461
+ query_length,
462
+ dropout=0.0,
463
+ softmax_scale=None,
464
  ):
465
  """
466
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
 
476
  attention_mask (`torch.Tensor`):
477
  The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
478
  position of padding tokens and 1 for the position of non-padding tokens.
479
+ dropout (`float`):
480
  Attention dropout
481
  softmax_scale (`float`, *optional*):
482
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
 
490
  # Contains at least one padding token in the sequence
491
  if attention_mask is not None:
492
  batch_size = query_states.shape[0]
493
+ (
494
+ query_states,
495
+ key_states,
496
+ value_states,
497
+ indices_q,
498
+ cu_seq_lens,
499
+ max_seq_lens,
500
+ ) = self._upad_input(
501
  query_states, key_states, value_states, attention_mask, query_length
502
  )
503
 
 
517
  causal=causal,
518
  )
519
 
520
+ attn_output = pad_input(
521
+ attn_output_unpad, indices_q, batch_size, query_length
522
+ )
523
  else:
524
  attn_output = flash_attn_func(
525
+ query_states,
526
+ key_states,
527
+ value_states,
528
+ dropout,
529
+ softmax_scale=softmax_scale,
530
+ causal=causal,
531
  )
532
 
533
  return attn_output
534
 
535
+ def _upad_input(
536
+ self, query_layer, key_layer, value_layer, attention_mask, query_length
537
+ ):
538
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
539
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
540
 
541
  key_layer = index_first_axis(
542
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
543
+ indices_k,
544
  )
545
  value_layer = index_first_axis(
546
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
547
+ indices_k,
548
  )
549
  if query_length == kv_seq_len:
550
  query_layer = index_first_axis(
551
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
552
+ indices_k,
553
  )
554
  cu_seqlens_q = cu_seqlens_k
555
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
 
564
  else:
565
  # The -q_len: slice assumes left padding.
566
  attention_mask = attention_mask[:, -query_length:]
567
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
568
+ query_layer, attention_mask
569
+ )
570
 
571
  return (
572
  query_layer,
 
619
  key_states = self.k_proj(hidden_states)
620
  value_states = self.v_proj(hidden_states)
621
 
622
+ query_states = query_states.view(
623
+ bsz, q_len, self.num_heads, self.head_dim
624
+ ).transpose(1, 2)
625
+ key_states = key_states.view(
626
+ bsz, q_len, self.num_key_value_heads, self.head_dim
627
+ ).transpose(1, 2)
628
+ value_states = value_states.view(
629
+ bsz, q_len, self.num_key_value_heads, self.head_dim
630
+ ).transpose(1, 2)
631
 
632
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
633
+ query_states, key_states = apply_rotary_pos_emb(
634
+ query_states, key_states, cos, sin, None
635
+ )
636
 
637
  past_key_value = getattr(self, "past_key_value", past_key_value)
638
 
639
  if past_key_value is not None:
640
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
641
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
642
+ key_states, value_states = past_key_value.update(
643
+ key_states, value_states, self.layer_idx, cache_kwargs
644
+ )
645
 
646
  key_states = repeat_kv(key_states, self.num_key_value_groups)
647
  value_states = repeat_kv(value_states, self.num_key_value_groups)
648
 
649
  causal_mask = attention_mask
650
+ if attention_mask is not None:
651
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
652
 
653
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
654
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
655
+ if causal_mask is not None:
656
  query_states = query_states.contiguous()
657
  key_states = key_states.contiguous()
658
  value_states = value_states.contiguous()
659
 
660
+ # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
661
+ # relying on the `is_causal` argument.
662
  attn_output = torch.nn.functional.scaled_dot_product_attention(
663
  query_states,
664
  key_states,
665
  value_states,
666
+ attn_mask=causal_mask,
667
  dropout_p=self.attention_dropout if self.training else 0.0,
668
+ is_causal=causal_mask is None and q_len > 1,
669
  )
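
A standalone sketch of the dispatch trade-off mentioned in the comment above (not part of the diff; shapes are illustrative): when no mask tensor is passed, `is_causal=True` lets SDPA apply causal masking internally and choose its fused backends, and for an unpadded prefill it matches an explicit additive causal mask up to kernel-level numerical noise. The `q_len > 1` guard skips the flag during single-token decoding, where causal masking over the whole cache is not what is wanted.

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)                   # (batch, heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)

bias = torch.zeros(16, 16)
bias.masked_fill_(torch.triu(torch.ones(16, 16, dtype=torch.bool), diagonal=1), float("-inf"))
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

print(torch.allclose(out_flag, out_mask, atol=1e-5))   # True
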
670
 
671
  attn_output = attn_output.transpose(1, 2).contiguous()
 
689
  super().__init__()
690
  self.hidden_size = config.hidden_size
691
 
692
+ self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](
693
+ config=config, layer_idx=layer_idx
694
+ )
695
 
696
  self.mlp = GemmaMLP(config)
697
  self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
698
+ self.post_attention_layernorm = GemmaRMSNorm(
699
+ config.hidden_size, eps=config.rms_norm_eps
700
+ )
701
 
702
  def forward(
703
  self,
 
709
  use_cache: Optional[bool] = False,
710
  cache_position: Optional[torch.LongTensor] = None,
711
  **kwargs,
712
+ ) -> Tuple[
713
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
714
+ ]:
715
  """
716
  Args:
717
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
746
  cache_position=cache_position,
747
  **kwargs,
748
  )
749
+ hidden_states = residual + hidden_states
750
 
751
  # Fully Connected
752
  residual = hidden_states
753
  hidden_states = self.post_attention_layernorm(hidden_states)
754
  hidden_states = self.mlp(hidden_states)
755
  hidden_states = residual + hidden_states
756
+
757
  outputs = (hidden_states,)
758
 
759
  if output_attentions:
 
808
  if module.padding_idx is not None:
809
  module.weight.data[module.padding_idx].zero_()
810
 
811
+ def _setup_cache(
812
+ self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None
813
+ ):
814
+ if (
815
+ self.config._attn_implementation == "flash_attention_2"
816
+ and cache_cls == StaticCache
817
+ ):
818
  raise ValueError(
819
  "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
820
  "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
821
  )
822
 
 
 
 
 
823
  for layer in self.model.layers:
824
  weights = layer.self_attn.o_proj.weight
825
  layer.self_attn.past_key_value = cache_cls(
826
+ self.config,
827
+ max_batch_size,
828
+ max_cache_len,
829
+ device=weights.device,
830
+ dtype=weights.dtype,
831
  )
832
 
833
  def _reset_cache(self):
 
902
  more detail.
903
  return_dict (`bool`, *optional*):
904
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
905
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
906
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
907
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
908
+ the complete sequence length.
909
  """
910
 
911
 
 
927
  self.padding_idx = config.pad_token_id
928
  self.vocab_size = config.vocab_size
929
 
930
+ self.embed_tokens = nn.Embedding(
931
+ config.vocab_size, config.hidden_size, self.padding_idx
932
+ )
933
  self.layers = nn.ModuleList(
934
+ [
935
+ GemmaDecoderLayer(config, layer_idx)
936
+ for layer_idx in range(config.num_hidden_layers)
937
+ ]
938
  )
939
  self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
940
  self.gradient_checkpointing = False
941
 
 
 
 
 
 
 
942
  # Initialize weights and apply final processing
943
  self.post_init()
944
 
 
963
  return_dict: Optional[bool] = None,
964
  cache_position: Optional[torch.LongTensor] = None,
965
  ) -> Union[Tuple, BaseModelOutputWithPast]:
966
+ output_attentions = (
967
+ output_attentions
968
+ if output_attentions is not None
969
+ else self.config.output_attentions
970
+ )
971
  output_hidden_states = (
972
+ output_hidden_states
973
+ if output_hidden_states is not None
974
+ else self.config.output_hidden_states
975
  )
976
  use_cache = use_cache if use_cache is not None else self.config.use_cache
977
+ return_dict = (
978
+ return_dict if return_dict is not None else self.config.use_return_dict
979
+ )
980
 
981
  if (input_ids is None) ^ (inputs_embeds is not None):
982
  raise ValueError(
 
1000
 
1001
  if cache_position is None:
1002
  cache_position = torch.arange(
1003
+ past_seen_tokens,
1004
+ past_seen_tokens + inputs_embeds.shape[1],
1005
+ device=inputs_embeds.device,
1006
  )
1007
 
1008
  if position_ids is None:
1009
  position_ids = cache_position.unsqueeze(0)
1010
 
1011
+ causal_mask = self._update_causal_mask(
1012
+ attention_mask, inputs_embeds, cache_position, past_seen_tokens
1013
+ )
1014
 
1015
  # embed positions
1016
  hidden_states = inputs_embeds
1017
 
1018
      # normalized
+     # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+     # See https://github.com/huggingface/transformers/pull/29402
+     normalizer = torch.tensor(
+         self.config.hidden_size**0.5, dtype=hidden_states.dtype
+     )
+     hidden_states = hidden_states * normalizer
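
A standalone numeric check of the precision note above (not part of the diff; dtype chosen for illustration): building the scale factor directly in the activation dtype bakes the rounded value into the multiplication, so the computation matches what a model running in that half-precision dtype actually sees.

import torch

scale = 3072 ** 0.5                                # 55.42562...
print(torch.tensor(scale, dtype=torch.float32))    # tensor(55.4256)
print(torch.tensor(scale, dtype=torch.bfloat16))   # tensor(55.5000, dtype=torch.bfloat16)
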
1025
 
1026
  # decoder layers
1027
  all_hidden_states = () if output_hidden_states else None
 
1071
  next_cache = None
1072
  if use_cache:
1073
  next_cache = (
1074
+ next_decoder_cache.to_legacy_cache()
1075
+ if isinstance(next_decoder_cache, Cache)
1076
+ else next_decoder_cache
1077
  )
1078
  if not return_dict:
1079
+ return tuple(
1080
+ v
1081
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1082
+ if v is not None
1083
+ )
1084
  return BaseModelOutputWithPast(
1085
  last_hidden_state=hidden_states,
1086
  past_key_values=next_cache,
 
1088
  attentions=all_self_attns,
1089
  )
1090
 
1091
+ def _update_causal_mask(
1092
+ self,
1093
+ attention_mask: torch.Tensor,
1094
+ input_tensor: torch.Tensor,
1095
+ cache_position: torch.Tensor,
1096
+ past_seen_tokens: int,
1097
+ ):
1098
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1099
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1100
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1101
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1102
+
1103
  if self.config._attn_implementation == "flash_attention_2":
1104
  if attention_mask is not None and 0.0 in attention_mask:
1105
  return attention_mask
1106
  return None
1107
 
1108
+ if self.config._attn_implementation == "sdpa":
1109
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
1110
+ # in order to dispatch on Flash Attention 2.
1111
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1112
+ attention_mask,
1113
+ inputs_embeds=input_tensor,
1114
+ past_key_values_length=past_seen_tokens,
1115
+ ):
1116
+ return None
1117
 
1118
+ dtype, device = input_tensor.dtype, input_tensor.device
1119
  min_dtype = torch.finfo(dtype).min
1120
+ sequence_length = input_tensor.shape[1]
1121
+ if hasattr(
1122
+ getattr(self.layers[0], "self_attn", {}), "past_key_value"
1123
+ ): # static cache
1124
+ target_length = self.config.max_position_embeddings
1125
+ else: # dynamic cache
1126
+ target_length = (
1127
+ attention_mask.shape[-1]
1128
+ if isinstance(attention_mask, torch.Tensor)
1129
+ else past_seen_tokens + sequence_length + 1
1130
+ )
1131
+
1132
+ causal_mask = torch.full(
1133
+ (sequence_length, target_length),
1134
+ fill_value=min_dtype,
1135
+ dtype=dtype,
1136
+ device=device,
1137
+ )
1138
+ if sequence_length != 1:
1139
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1140
+ causal_mask *= torch.arange(
1141
+ target_length, device=device
1142
+ ) > cache_position.reshape(-1, 1)
1143
+ causal_mask = causal_mask[None, None, :, :].expand(
1144
+ input_tensor.shape[0], 1, -1, -1
1145
+ )
1146
+ if attention_mask is not None:
1147
+ causal_mask = (
1148
+ causal_mask.clone()
1149
+ ) # copy to contiguous memory for in-place edit
1150
+ if attention_mask.dim() == 2:
1151
+ mask_length = attention_mask.shape[-1]
1152
+ padding_mask = (
1153
+ causal_mask[:, :, :, :mask_length]
1154
+ + attention_mask[:, None, None, :]
1155
+ )
1156
+ padding_mask = padding_mask == 0
1157
+ causal_mask[:, :, :, :mask_length] = causal_mask[
1158
+ :, :, :, :mask_length
1159
+ ].masked_fill(padding_mask, min_dtype)
1160
+ elif attention_mask.dim() == 4:
1161
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1162
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
1163
+ if attention_mask.shape[-2] < cache_position[0] + sequence_length:
1164
+ offset = cache_position[0]
1165
+ else:
1166
+ offset = 0
1167
+ mask_shape = attention_mask.shape
1168
+ mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
1169
+ causal_mask[
1170
+ : mask_shape[0],
1171
+ : mask_shape[1],
1172
+ offset : mask_shape[2] + offset,
1173
+ : mask_shape[3],
1174
+ ] = mask_slice
1175
+
1176
+ if (
1177
+ self.config._attn_implementation == "sdpa"
1178
+ and attention_mask is not None
1179
+ and attention_mask.device.type == "cuda"
1180
+ ):
1181
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1182
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1183
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1184
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1185
+ causal_mask, min_dtype
1186
  )
 
 
 
 
 
1187
 
1188
  return causal_mask
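
A standalone sketch of the mask this method builds in the simplest case (no padding, dynamic cache, sequence_length == target_length == 4; not part of the diff): positions a query may attend to carry 0 and future positions carry the most negative representable value, which the attention modules then add to the raw scores.

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length = target_length = 4
cache_position = torch.arange(sequence_length)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
print(causal_mask)
# row i is 0 for columns <= i and min_dtype for columns > i; the method then expands this to
# (batch, 1, seq, target) and merges the 2D padding mask into it when one is provided.
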
1189
 
 
1221
 
1222
  # Ignore copy
1223
  @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
1224
+ @replace_return_docstrings(
1225
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1226
+ )
1227
  def forward(
1228
  self,
1229
  input_ids: torch.LongTensor = None,
 
1263
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1264
  "What is your favorite condiment?"
1265
  ```"""
1266
+ output_attentions = (
1267
+ output_attentions
1268
+ if output_attentions is not None
1269
+ else self.config.output_attentions
1270
+ )
1271
  output_hidden_states = (
1272
+ output_hidden_states
1273
+ if output_hidden_states is not None
1274
+ else self.config.output_hidden_states
1275
+ )
1276
+ return_dict = (
1277
+ return_dict if return_dict is not None else self.config.use_return_dict
1278
  )
 
1279
 
1280
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1281
  outputs = self.model(
 
1320
  )
1321
 
1322
  def prepare_inputs_for_generation(
1323
+ self,
1324
+ input_ids,
1325
+ past_key_values=None,
1326
+ attention_mask=None,
1327
+ inputs_embeds=None,
1328
+ cache_position=None,
1329
+ **kwargs,
1330
  ):
1331
+ # With static cache, the `past_key_values` is None
1332
+ # TODO joao: standardize interface for the different Cache classes and remove of this if
1333
+ has_static_cache = False
1334
+ if past_key_values is None:
1335
+ past_key_values = getattr(
1336
+ getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None
1337
+ )
1338
+ has_static_cache = past_key_values is not None
1339
+
1340
  past_length = 0
1341
  if past_key_values is not None:
1342
  if isinstance(past_key_values, Cache):
1343
+ past_length = (
1344
+ cache_position[0]
1345
+ if cache_position is not None
1346
+ else past_key_values.get_seq_length()
1347
+ )
1348
+ max_cache_length = (
1349
+ torch.tensor(
1350
+ past_key_values.get_max_length(), device=input_ids.device
1351
+ )
1352
+ if past_key_values.get_max_length() is not None
1353
+ else None
1354
+ )
1355
+ cache_length = (
1356
+ past_length
1357
+ if max_cache_length is None
1358
+ else torch.min(max_cache_length, past_length)
1359
+ )
1360
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
1361
  else:
1362
  cache_length = past_length = past_key_values[0][0].shape[2]
1363
  max_cache_length = None
 
1366
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1367
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1368
  # input)
1369
+ if (
1370
+ attention_mask is not None
1371
+ and attention_mask.shape[1] > input_ids.shape[1]
1372
+ ):
1373
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1374
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1375
  # input_ids based on the past_length.
 
1393
  if past_key_values:
1394
  position_ids = position_ids[:, -input_ids.shape[1] :]
1395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1396
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1397
  if inputs_embeds is not None and past_key_values is None:
1398
  model_inputs = {"inputs_embeds": inputs_embeds}
 
1402
  # TODO: use `next_tokens` directly instead.
1403
  model_inputs = {"input_ids": input_ids.contiguous()}
1404
 
1405
+ input_length = (
1406
+ position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1407
+ )
1408
+ if cache_position is None:
1409
+ cache_position = torch.arange(
1410
+ past_length, past_length + input_length, device=input_ids.device
1411
+ )
1412
+ else:
1413
+ cache_position = cache_position[-input_length:]
1414
+
1415
+ if has_static_cache:
1416
+ past_key_values = None
1417
+
1418
  model_inputs.update(
1419
  {
1420
+ "position_ids": position_ids,
1421
  "cache_position": cache_position,
1422
  "past_key_values": past_key_values,
1423
  "use_cache": kwargs.get("use_cache"),
 
1431
  reordered_past = ()
1432
  for layer_past in past_key_values:
1433
  reordered_past += (
1434
+ tuple(
1435
+ past_state.index_select(0, beam_idx.to(past_state.device))
1436
+ for past_state in layer_past
1437
+ ),
1438
  )
1439
  return reordered_past
1440
 
 
1491
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1492
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1493
  """
1494
+ return_dict = (
1495
+ return_dict if return_dict is not None else self.config.use_return_dict
1496
+ )
1497
 
1498
  transformer_outputs = self.model(
1499
  input_ids,
 
1515
  batch_size = inputs_embeds.shape[0]
1516
 
1517
  if self.config.pad_token_id is None and batch_size != 1:
1518
+ raise ValueError(
1519
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1520
+ )
1521
  if self.config.pad_token_id is None:
1522
  sequence_lengths = -1
1523
  else:
1524
  if input_ids is not None:
1525
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1526
+ sequence_lengths = (
1527
+ torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1528
+ )
1529
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1530
  sequence_lengths = sequence_lengths.to(logits.device)
1531
  else:
1532
  sequence_lengths = -1
1533
 
1534
+ pooled_logits = logits[
1535
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1536
+ ]
1537
 
1538
  loss = None
1539
  if labels is not None:
 
1541
  if self.config.problem_type is None:
1542
  if self.num_labels == 1:
1543
  self.config.problem_type = "regression"
1544
+ elif self.num_labels > 1 and (
1545
+ labels.dtype == torch.long or labels.dtype == torch.int
1546
+ ):
1547
  self.config.problem_type = "single_label_classification"
1548
  else:
1549
  self.config.problem_type = "multi_label_classification"
 
1556
  loss = loss_fct(pooled_logits, labels)
1557
  elif self.config.problem_type == "single_label_classification":
1558
  loss_fct = CrossEntropyLoss()
1559
+ loss = loss_fct(
1560
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1561
+ )
1562
  elif self.config.problem_type == "multi_label_classification":
1563
  loss_fct = BCEWithLogitsLoss()
1564
  loss = loss_fct(pooled_logits, labels)