zifei9 committed on
Commit 8745eab · verified · 1 Parent(s): e4a91a4

Update modeling_gemma.py

rebasing on latest transformers

Files changed (1)
  1. modeling_gemma.py +387 -573
modeling_gemma.py CHANGED
@@ -1,3 +1,9 @@
1
  # coding=utf-8
2
  # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
3
  #
@@ -13,74 +19,111 @@
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
- """ PyTorch Gemma model."""
17
-
18
  import math
19
- import warnings
20
  from typing import List, Optional, Tuple, Union
21
 
22
  import torch
23
- import torch.nn.functional as F
24
  import torch.utils.checkpoint
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
 
28
- from transformers.activations import ACT2FN
29
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
- from transformers.modeling_attn_mask_utils import (
31
- AttentionMaskConverter,
32
- _prepare_4d_causal_attention_mask,
33
- )
34
- from transformers.modeling_outputs import (
35
  BaseModelOutputWithPast,
36
  CausalLMOutputWithPast,
37
  SequenceClassifierOutputWithPast,
 
38
  )
39
- from transformers.modeling_utils import PreTrainedModel
40
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
41
- from transformers.utils import (
42
  add_start_docstrings,
43
  add_start_docstrings_to_model_forward,
44
- is_flash_attn_2_available,
45
  is_flash_attn_greater_or_equal_2_10,
46
  logging,
47
  replace_return_docstrings,
48
  )
49
- from transformers.utils.import_utils import is_torch_fx_available
50
  from .configuration_gemma import GemmaConfig
51
 
52
 
53
- if is_flash_attn_2_available():
54
- from flash_attn import flash_attn_func, flash_attn_varlen_func
55
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
56
 
57
- from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
58
- # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
59
- # It means that the function will not be traced through and simply appear as a node in the graph.
60
- if is_torch_fx_available():
61
- if not is_torch_greater_or_equal_than_1_13:
62
- import torch.fx
63
 
64
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
65
 
66
 
67
- logger = logging.get_logger(__name__)
68
 
69
- _CONFIG_FOR_DOC = "GemmaConfig"
 
70
 
71
 
72
- def _get_unpad_data(attention_mask):
73
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
74
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
75
- max_seqlen_in_batch = seqlens_in_batch.max().item()
76
- cu_seqlens = F.pad(
77
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
78
- )
79
- return (
80
- indices,
81
- cu_seqlens,
82
- max_seqlen_in_batch,
83
- )
84
 
85
 
86
  ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
@@ -93,45 +136,80 @@ class GemmaRotaryEmbedding(nn.Module):
93
  self.dim = dim
94
  self.max_position_embeddings = max_position_embeddings
95
  self.base = base
96
- # self.register_buffer("inv_freq", None, persistent=False)
97
- self.inv_freq = None
98
 
99
  @torch.no_grad()
100
  def forward(self, x, position_ids, seq_len=None):
101
  # x: [bs, num_attention_heads, seq_len, head_size]
102
- if self.inv_freq is None:
103
- self.inv_freq = 1.0 / (
104
- self.base
105
- ** (
106
- torch.arange(
107
- 0, self.dim, 2, dtype=torch.int64, device=x.device
108
- ).float()
109
- / self.dim
110
- )
111
- )
112
- inv_freq_expanded = (
113
- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
114
- )
115
  position_ids_expanded = position_ids[:, None, :].float()
116
  # Force float32 since bfloat16 loses precision on long contexts
117
  # See https://github.com/huggingface/transformers/pull/29285
118
  device_type = x.device.type
119
- device_type = (
120
- device_type
121
- if isinstance(device_type, str) and device_type != "mps"
122
- else "cpu"
123
- )
124
  with torch.autocast(device_type=device_type, enabled=False):
125
- freqs = (
126
- inv_freq_expanded.float() @ position_ids_expanded.float()
127
- ).transpose(1, 2)
128
  emb = torch.cat((freqs, freqs), dim=-1)
129
  cos = emb.cos()
130
  sin = emb.sin()
131
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
132
 
133
 
134
- # Copied from transformers.models.llama.modeling_llama.rotate_half
135
  def rotate_half(x):
136
  """Rotates half the hidden dims of the input."""
137
  x1 = x[..., : x.shape[-1] // 2]
@@ -139,7 +217,6 @@ def rotate_half(x):
139
  return torch.cat((-x2, x1), dim=-1)
140
 
141
 
142
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
143
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
144
  """Applies Rotary Position Embedding to the query and key tensors.
145
 
@@ -167,33 +244,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
167
  return q_embed, k_embed
168
 
169
 
170
- class GemmaMLP(nn.Module):
171
- def __init__(self, config):
172
- super().__init__()
173
- self.config = config
174
- self.hidden_size = config.hidden_size
175
- self.intermediate_size = config.intermediate_size
176
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
177
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
178
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
179
- if config.hidden_activation is None:
180
- logger.warning_once(
181
- "Gemma's activation function should be approximate GeLU and not exact GeLU.\n"
182
- "Changing the activation function to `gelu_pytorch_tanh`."
183
- f"if you want to use the legacy `{config.hidden_act}`, "
184
- f"edit the `model.config` to set `hidden_activation={config.hidden_act}` "
185
- " instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
186
- )
187
- hidden_activation = "gelu_pytorch_tanh"
188
- else:
189
- hidden_activation = config.hidden_activation
190
- self.act_fn = ACT2FN[hidden_activation]
191
-
192
- def forward(self, x):
193
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
194
-
195
-
196
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
197
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
198
  """
199
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -202,16 +252,13 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
202
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
203
  if n_rep == 1:
204
  return hidden_states
205
- hidden_states = hidden_states[:, :, None, :, :].expand(
206
- batch, num_key_value_heads, n_rep, slen, head_dim
207
- )
208
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
209
 
210
 
211
  class GemmaAttention(nn.Module):
212
  """Multi-headed attention from 'Attention Is All You Need' paper"""
213
 
214
- # Ignore copy
215
  def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
216
  super().__init__()
217
  self.config = config
@@ -232,6 +279,7 @@ class GemmaAttention(nn.Module):
232
  self.max_position_embeddings = config.max_position_embeddings
233
  self.rope_theta = config.rope_theta
234
  self.is_causal = True
 
235
 
236
  if self.hidden_size % self.num_heads != 0:
237
  raise ValueError(
@@ -239,22 +287,10 @@ class GemmaAttention(nn.Module):
239
  f" and `num_heads`: {self.num_heads})."
240
  )
241
 
242
- self.q_proj = nn.Linear(
243
- self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
244
- )
245
- self.k_proj = nn.Linear(
246
- self.hidden_size,
247
- self.num_key_value_heads * self.head_dim,
248
- bias=config.attention_bias,
249
- )
250
- self.v_proj = nn.Linear(
251
- self.hidden_size,
252
- self.num_key_value_heads * self.head_dim,
253
- bias=config.attention_bias,
254
- )
255
- self.o_proj = nn.Linear(
256
- self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
257
- )
258
  self.rotary_emb = GemmaRotaryEmbedding(
259
  self.head_dim,
260
  max_position_embeddings=self.max_position_embeddings,
@@ -270,7 +306,6 @@ class GemmaAttention(nn.Module):
270
  output_attentions: bool = False,
271
  use_cache: bool = False,
272
  cache_position: Optional[torch.LongTensor] = None,
273
- **kwargs,
274
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
275
  bsz, q_len, _ = hidden_states.size()
276
 
@@ -278,47 +313,30 @@ class GemmaAttention(nn.Module):
278
  key_states = self.k_proj(hidden_states)
279
  value_states = self.v_proj(hidden_states)
280
 
281
- query_states = query_states.view(
282
- bsz, q_len, self.num_heads, self.head_dim
283
- ).transpose(1, 2)
284
- key_states = key_states.view(
285
- bsz, q_len, self.num_key_value_heads, self.head_dim
286
- ).transpose(1, 2)
287
- value_states = value_states.view(
288
- bsz, q_len, self.num_key_value_heads, self.head_dim
289
- ).transpose(1, 2)
290
-
291
- past_key_value = getattr(self, "past_key_value", past_key_value)
292
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
293
- query_states, key_states = apply_rotary_pos_emb(
294
- query_states, key_states, cos, sin, None
295
- )
296
 
297
  if past_key_value is not None:
298
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
299
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
300
- key_states, value_states = past_key_value.update(
301
- key_states, value_states, self.layer_idx, cache_kwargs
302
- )
303
 
304
  key_states = repeat_kv(key_states, self.num_key_value_groups)
305
  value_states = repeat_kv(value_states, self.num_key_value_groups)
306
 
307
- attn_weights = torch.matmul(
308
- query_states, key_states.transpose(2, 3)
309
- ) / math.sqrt(self.head_dim)
310
 
311
  if attention_mask is not None: # no matter the length, we just slice it
312
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
313
  attn_weights = attn_weights + causal_mask
314
 
315
  # upcast attention to fp32
316
- attn_weights = nn.functional.softmax(
317
- attn_weights, dim=-1, dtype=torch.float32
318
- ).to(query_states.dtype)
319
- attn_weights = nn.functional.dropout(
320
- attn_weights, p=self.attention_dropout, training=self.training
321
- )
322
  attn_output = torch.matmul(attn_weights, value_states)
323
 
324
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -338,7 +356,6 @@ class GemmaAttention(nn.Module):
338
  return attn_output, attn_weights, past_key_value
339
 
340
 
341
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemma
342
  class GemmaFlashAttention2(GemmaAttention):
343
  """
344
  Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
@@ -354,7 +371,6 @@ class GemmaFlashAttention2(GemmaAttention):
354
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
355
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
356
 
357
- # Ignore copy
358
  def forward(
359
  self,
360
  hidden_states: torch.Tensor,
@@ -364,8 +380,13 @@ class GemmaFlashAttention2(GemmaAttention):
364
  output_attentions: bool = False,
365
  use_cache: bool = False,
366
  cache_position: Optional[torch.LongTensor] = None,
367
- **kwargs,
368
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
369
  output_attentions = False
370
 
371
  bsz, q_len, _ = hidden_states.size()
@@ -377,29 +398,17 @@ class GemmaFlashAttention2(GemmaAttention):
377
  # Flash attention requires the input to have the shape
378
  # batch_size x seq_length x head_dim x hidden_dim
379
  # therefore we just need to keep the original shape
380
- query_states = query_states.view(
381
- bsz, q_len, self.num_heads, self.head_dim
382
- ).transpose(1, 2)
383
- key_states = key_states.view(
384
- bsz, q_len, self.num_key_value_heads, self.head_dim
385
- ).transpose(1, 2)
386
- value_states = value_states.view(
387
- bsz, q_len, self.num_key_value_heads, self.head_dim
388
- ).transpose(1, 2)
389
-
390
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
391
- query_states, key_states = apply_rotary_pos_emb(
392
- query_states, key_states, cos, sin, None
393
- )
394
 
395
- past_key_value = getattr(self, "past_key_value", past_key_value)
 
396
 
397
  if past_key_value is not None:
398
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
399
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
400
- key_states, value_states = past_key_value.update(
401
- key_states, value_states, self.layer_idx, cache_kwargs
402
- )
403
 
404
  # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
405
  # to be able to avoid many of these transpose/reshape/view.
@@ -435,13 +444,17 @@ class GemmaFlashAttention2(GemmaAttention):
435
  key_states = key_states.to(target_dtype)
436
  value_states = value_states.to(target_dtype)
437
 
438
- attn_output = self._flash_attention_forward(
439
  query_states,
440
  key_states,
441
  value_states,
442
  attention_mask,
443
  q_len,
 
444
  dropout=dropout_rate,
 
 
 
445
  )
446
 
447
  attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -452,133 +465,7 @@ class GemmaFlashAttention2(GemmaAttention):
452
 
453
  return attn_output, attn_weights, past_key_value
454
 
455
- def _flash_attention_forward(
456
- self,
457
- query_states,
458
- key_states,
459
- value_states,
460
- attention_mask,
461
- query_length,
462
- dropout=0.0,
463
- softmax_scale=None,
464
- ):
465
- """
466
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
467
- first unpad the input, then computes the attention scores and pad the final attention scores.
468
-
469
- Args:
470
- query_states (`torch.Tensor`):
471
- Input query states to be passed to Flash Attention API
472
- key_states (`torch.Tensor`):
473
- Input key states to be passed to Flash Attention API
474
- value_states (`torch.Tensor`):
475
- Input value states to be passed to Flash Attention API
476
- attention_mask (`torch.Tensor`):
477
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
478
- position of padding tokens and 1 for the position of non-padding tokens.
479
- dropout (`float`):
480
- Attention dropout
481
- softmax_scale (`float`, *optional*):
482
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
483
- """
484
- if not self._flash_attn_uses_top_left_mask:
485
- causal = self.is_causal
486
- else:
487
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in GemmaFlashAttention2 __init__.
488
- causal = self.is_causal and query_length != 1
489
-
490
- # Contains at least one padding token in the sequence
491
- if attention_mask is not None:
492
- batch_size = query_states.shape[0]
493
- (
494
- query_states,
495
- key_states,
496
- value_states,
497
- indices_q,
498
- cu_seq_lens,
499
- max_seq_lens,
500
- ) = self._upad_input(
501
- query_states, key_states, value_states, attention_mask, query_length
502
- )
503
 
504
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
505
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
506
-
507
- attn_output_unpad = flash_attn_varlen_func(
508
- query_states,
509
- key_states,
510
- value_states,
511
- cu_seqlens_q=cu_seqlens_q,
512
- cu_seqlens_k=cu_seqlens_k,
513
- max_seqlen_q=max_seqlen_in_batch_q,
514
- max_seqlen_k=max_seqlen_in_batch_k,
515
- dropout_p=dropout,
516
- softmax_scale=softmax_scale,
517
- causal=causal,
518
- )
519
-
520
- attn_output = pad_input(
521
- attn_output_unpad, indices_q, batch_size, query_length
522
- )
523
- else:
524
- attn_output = flash_attn_func(
525
- query_states,
526
- key_states,
527
- value_states,
528
- dropout,
529
- softmax_scale=softmax_scale,
530
- causal=causal,
531
- )
532
-
533
- return attn_output
534
-
535
- def _upad_input(
536
- self, query_layer, key_layer, value_layer, attention_mask, query_length
537
- ):
538
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
539
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
540
-
541
- key_layer = index_first_axis(
542
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
543
- indices_k,
544
- )
545
- value_layer = index_first_axis(
546
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
547
- indices_k,
548
- )
549
- if query_length == kv_seq_len:
550
- query_layer = index_first_axis(
551
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
552
- indices_k,
553
- )
554
- cu_seqlens_q = cu_seqlens_k
555
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
556
- indices_q = indices_k
557
- elif query_length == 1:
558
- max_seqlen_in_batch_q = 1
559
- cu_seqlens_q = torch.arange(
560
- batch_size + 1, dtype=torch.int32, device=query_layer.device
561
- ) # There is a memcpy here, that is very bad.
562
- indices_q = cu_seqlens_q[:-1]
563
- query_layer = query_layer.squeeze(1)
564
- else:
565
- # The -q_len: slice assumes left padding.
566
- attention_mask = attention_mask[:, -query_length:]
567
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
568
- query_layer, attention_mask
569
- )
570
-
571
- return (
572
- query_layer,
573
- key_layer,
574
- value_layer,
575
- indices_q,
576
- (cu_seqlens_q, cu_seqlens_k),
577
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
578
- )
579
-
580
-
581
- # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemma
582
  class GemmaSdpaAttention(GemmaAttention):
583
  """
584
  Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -586,7 +473,7 @@ class GemmaSdpaAttention(GemmaAttention):
586
  SDPA API.
587
  """
588
 
589
- # Ignore copy
590
  def forward(
591
  self,
592
  hidden_states: torch.Tensor,
@@ -596,6 +483,7 @@ class GemmaSdpaAttention(GemmaAttention):
596
  output_attentions: bool = False,
597
  use_cache: bool = False,
598
  cache_position: Optional[torch.LongTensor] = None,
 
599
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
600
  if output_attentions:
601
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -619,29 +507,17 @@ class GemmaSdpaAttention(GemmaAttention):
619
  key_states = self.k_proj(hidden_states)
620
  value_states = self.v_proj(hidden_states)
621
 
622
- query_states = query_states.view(
623
- bsz, q_len, self.num_heads, self.head_dim
624
- ).transpose(1, 2)
625
- key_states = key_states.view(
626
- bsz, q_len, self.num_key_value_heads, self.head_dim
627
- ).transpose(1, 2)
628
- value_states = value_states.view(
629
- bsz, q_len, self.num_key_value_heads, self.head_dim
630
- ).transpose(1, 2)
631
-
632
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
633
- query_states, key_states = apply_rotary_pos_emb(
634
- query_states, key_states, cos, sin, None
635
- )
636
 
637
- past_key_value = getattr(self, "past_key_value", past_key_value)
 
638
 
639
  if past_key_value is not None:
640
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
641
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
642
- key_states, value_states = past_key_value.update(
643
- key_states, value_states, self.layer_idx, cache_kwargs
644
- )
645
 
646
  key_states = repeat_kv(key_states, self.num_key_value_groups)
647
  value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -657,15 +533,17 @@ class GemmaSdpaAttention(GemmaAttention):
657
  key_states = key_states.contiguous()
658
  value_states = value_states.contiguous()
659
 
660
- # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
661
- # relying on the `is_causal` argument.
 
 
662
  attn_output = torch.nn.functional.scaled_dot_product_attention(
663
  query_states,
664
  key_states,
665
  value_states,
666
  attn_mask=causal_mask,
667
  dropout_p=self.attention_dropout if self.training else 0.0,
668
- is_causal=causal_mask is None and q_len > 1,
669
  )
670
 
671
  attn_output = attn_output.transpose(1, 2).contiguous()
@@ -683,35 +561,28 @@ GEMMA_ATTENTION_CLASSES = {
683
  }
684
 
685
 
686
- # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
687
  class GemmaDecoderLayer(nn.Module):
688
  def __init__(self, config: GemmaConfig, layer_idx: int):
689
  super().__init__()
690
  self.hidden_size = config.hidden_size
691
 
692
- self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](
693
- config=config, layer_idx=layer_idx
694
- )
695
 
696
  self.mlp = GemmaMLP(config)
697
  self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
698
- self.post_attention_layernorm = GemmaRMSNorm(
699
- config.hidden_size, eps=config.rms_norm_eps
700
- )
701
 
702
  def forward(
703
  self,
704
  hidden_states: torch.Tensor,
705
  attention_mask: Optional[torch.Tensor] = None,
706
  position_ids: Optional[torch.LongTensor] = None,
707
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
708
  output_attentions: Optional[bool] = False,
709
  use_cache: Optional[bool] = False,
710
  cache_position: Optional[torch.LongTensor] = None,
711
  **kwargs,
712
- ) -> Tuple[
713
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
714
- ]:
715
  """
716
  Args:
717
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -725,12 +596,12 @@ class GemmaDecoderLayer(nn.Module):
725
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
726
  (see `past_key_values`).
727
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
728
  """
729
- if "padding_mask" in kwargs:
730
- warnings.warn(
731
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
732
- )
733
-
734
  residual = hidden_states
735
 
736
  hidden_states = self.input_layernorm(hidden_states)
@@ -790,12 +661,13 @@ class GemmaPreTrainedModel(PreTrainedModel):
790
  config_class = GemmaConfig
791
  base_model_prefix = "model"
792
  supports_gradient_checkpointing = True
793
- _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
794
  _no_split_modules = ["GemmaDecoderLayer"]
795
- _skip_keys_device_placement = ["past_key_values", "causal_mask"]
796
  _supports_flash_attn_2 = True
797
  _supports_sdpa = True
798
  _supports_cache_class = True
 
 
799
 
800
  def _init_weights(self, module):
801
  std = self.config.initializer_range
@@ -808,31 +680,8 @@ class GemmaPreTrainedModel(PreTrainedModel):
808
  if module.padding_idx is not None:
809
  module.weight.data[module.padding_idx].zero_()
810
 
811
- def _setup_cache(
812
- self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None
813
- ):
814
- if (
815
- self.config._attn_implementation == "flash_attention_2"
816
- and cache_cls == StaticCache
817
- ):
818
- raise ValueError(
819
- "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
820
- "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
821
- )
822
-
823
- for layer in self.model.layers:
824
- weights = layer.self_attn.o_proj.weight
825
- layer.self_attn.past_key_value = cache_cls(
826
- self.config,
827
- max_batch_size,
828
- max_cache_len,
829
- device=weights.device,
830
- dtype=weights.dtype,
831
- )
832
 
833
- def _reset_cache(self):
834
- for layer in self.model.layers:
835
- layer.self_attn.past_key_value = None
836
 
837
 
838
  GEMMA_INPUTS_DOCSTRING = r"""
@@ -913,7 +762,6 @@ GEMMA_INPUTS_DOCSTRING = r"""
913
  "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
914
  GEMMA_START_DOCSTRING,
915
  )
916
- # Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMA,Llama->Gemma
917
  class GemmaModel(GemmaPreTrainedModel):
918
  """
919
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
@@ -927,14 +775,9 @@ class GemmaModel(GemmaPreTrainedModel):
927
  self.padding_idx = config.pad_token_id
928
  self.vocab_size = config.vocab_size
929
 
930
- self.embed_tokens = nn.Embedding(
931
- config.vocab_size, config.hidden_size, self.padding_idx
932
- )
933
  self.layers = nn.ModuleList(
934
- [
935
- GemmaDecoderLayer(config, layer_idx)
936
- for layer_idx in range(config.num_hidden_layers)
937
- ]
938
  )
939
  self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
940
  self.gradient_checkpointing = False
@@ -949,13 +792,12 @@ class GemmaModel(GemmaPreTrainedModel):
949
  self.embed_tokens = value
950
 
951
  @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
952
- # Ignore copy
953
  def forward(
954
  self,
955
  input_ids: torch.LongTensor = None,
956
  attention_mask: Optional[torch.Tensor] = None,
957
  position_ids: Optional[torch.LongTensor] = None,
958
- past_key_values: Optional[List[torch.FloatTensor]] = None,
959
  inputs_embeds: Optional[torch.FloatTensor] = None,
960
  use_cache: Optional[bool] = None,
961
  output_attentions: Optional[bool] = None,
@@ -963,20 +805,12 @@ class GemmaModel(GemmaPreTrainedModel):
963
  return_dict: Optional[bool] = None,
964
  cache_position: Optional[torch.LongTensor] = None,
965
  ) -> Union[Tuple, BaseModelOutputWithPast]:
966
- output_attentions = (
967
- output_attentions
968
- if output_attentions is not None
969
- else self.config.output_attentions
970
- )
971
  output_hidden_states = (
972
- output_hidden_states
973
- if output_hidden_states is not None
974
- else self.config.output_hidden_states
975
  )
976
  use_cache = use_cache if use_cache is not None else self.config.use_cache
977
- return_dict = (
978
- return_dict if return_dict is not None else self.config.use_return_dict
979
- )
980
 
981
  if (input_ids is None) ^ (inputs_embeds is not None):
982
  raise ValueError(
@@ -992,24 +826,24 @@ class GemmaModel(GemmaPreTrainedModel):
992
  if inputs_embeds is None:
993
  inputs_embeds = self.embed_tokens(input_ids)
994
 
995
- past_seen_tokens = 0
996
- if use_cache: # kept for BC (cache positions)
997
- if not isinstance(past_key_values, StaticCache):
998
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
999
- past_seen_tokens = past_key_values.get_seq_length()
 
1000
 
1001
  if cache_position is None:
 
1002
  cache_position = torch.arange(
1003
- past_seen_tokens,
1004
- past_seen_tokens + inputs_embeds.shape[1],
1005
- device=inputs_embeds.device,
1006
  )
1007
 
1008
  if position_ids is None:
1009
  position_ids = cache_position.unsqueeze(0)
1010
 
1011
  causal_mask = self._update_causal_mask(
1012
- attention_mask, inputs_embeds, cache_position, past_seen_tokens
1013
  )
1014
 
1015
  # embed positions
@@ -1018,10 +852,17 @@ class GemmaModel(GemmaPreTrainedModel):
1018
  # normalized
1019
  # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
1020
  # See https://github.com/huggingface/transformers/pull/29402
1021
- normalizer = torch.tensor(
1022
- self.config.hidden_size**0.5, dtype=hidden_states.dtype
1023
- )
1024
  hidden_states = hidden_states * normalizer
1025
 
1026
  # decoder layers
1027
  all_hidden_states = () if output_hidden_states else None
@@ -1068,19 +909,12 @@ class GemmaModel(GemmaPreTrainedModel):
1068
  if output_hidden_states:
1069
  all_hidden_states += (hidden_states,)
1070
 
1071
- next_cache = None
1072
- if use_cache:
1073
- next_cache = (
1074
- next_decoder_cache.to_legacy_cache()
1075
- if isinstance(next_decoder_cache, Cache)
1076
- else next_decoder_cache
1077
- )
1078
  if not return_dict:
1079
- return tuple(
1080
- v
1081
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1082
- if v is not None
1083
- )
1084
  return BaseModelOutputWithPast(
1085
  last_hidden_state=hidden_states,
1086
  past_key_values=next_cache,
@@ -1093,7 +927,8 @@ class GemmaModel(GemmaPreTrainedModel):
1093
  attention_mask: torch.Tensor,
1094
  input_tensor: torch.Tensor,
1095
  cache_position: torch.Tensor,
1096
- past_seen_tokens: int,
 
1097
  ):
1098
  # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1099
  # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1105,90 +940,59 @@ class GemmaModel(GemmaPreTrainedModel):
1105
  return attention_mask
1106
  return None
1107
 
1108
- if self.config._attn_implementation == "sdpa":
1109
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
1110
- # in order to dispatch on Flash Attention 2.
1111
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
1112
  attention_mask,
1113
  inputs_embeds=input_tensor,
1114
  past_key_values_length=past_seen_tokens,
 
1115
  ):
1116
  return None
1117
 
1118
  dtype, device = input_tensor.dtype, input_tensor.device
1119
  min_dtype = torch.finfo(dtype).min
1120
  sequence_length = input_tensor.shape[1]
1121
- if hasattr(
1122
- getattr(self.layers[0], "self_attn", {}), "past_key_value"
1123
- ): # static cache
1124
- target_length = self.config.max_position_embeddings
1125
- else: # dynamic cache
1126
  target_length = (
1127
  attention_mask.shape[-1]
1128
  if isinstance(attention_mask, torch.Tensor)
1129
  else past_seen_tokens + sequence_length + 1
1130
  )
1131
 
1132
- causal_mask = torch.full(
1133
- (sequence_length, target_length),
1134
- fill_value=min_dtype,
 
 
1135
  dtype=dtype,
1136
  device=device,
 
 
 
1137
  )
1138
- if sequence_length != 1:
1139
- causal_mask = torch.triu(causal_mask, diagonal=1)
1140
- causal_mask *= torch.arange(
1141
- target_length, device=device
1142
- ) > cache_position.reshape(-1, 1)
1143
- causal_mask = causal_mask[None, None, :, :].expand(
1144
- input_tensor.shape[0], 1, -1, -1
1145
- )
1146
- if attention_mask is not None:
1147
- causal_mask = (
1148
- causal_mask.clone()
1149
- ) # copy to contiguous memory for in-place edit
1150
- if attention_mask.dim() == 2:
1151
- mask_length = attention_mask.shape[-1]
1152
- padding_mask = (
1153
- causal_mask[:, :, :, :mask_length]
1154
- + attention_mask[:, None, None, :]
1155
- )
1156
- padding_mask = padding_mask == 0
1157
- causal_mask[:, :, :, :mask_length] = causal_mask[
1158
- :, :, :, :mask_length
1159
- ].masked_fill(padding_mask, min_dtype)
1160
- elif attention_mask.dim() == 4:
1161
- # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1162
- # cache. In that case, the 4D attention mask attends to the newest tokens only.
1163
- if attention_mask.shape[-2] < cache_position[0] + sequence_length:
1164
- offset = cache_position[0]
1165
- else:
1166
- offset = 0
1167
- mask_shape = attention_mask.shape
1168
- mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
1169
- causal_mask[
1170
- : mask_shape[0],
1171
- : mask_shape[1],
1172
- offset : mask_shape[2] + offset,
1173
- : mask_shape[3],
1174
- ] = mask_slice
1175
-
1176
  if (
1177
  self.config._attn_implementation == "sdpa"
1178
  and attention_mask is not None
1179
  and attention_mask.device.type == "cuda"
 
1180
  ):
1181
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1182
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1183
  # Details: https://github.com/pytorch/pytorch/issues/110213
1184
- causal_mask = AttentionMaskConverter._unmask_unattended(
1185
- causal_mask, min_dtype
1186
- )
1187
 
1188
  return causal_mask
1189
 
1190
 
1191
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMA,Llama->Gemma,llama->gemma
1192
  class GemmaForCausalLM(GemmaPreTrainedModel):
1193
  _tied_weights_keys = ["lm_head.weight"]
1194
 
@@ -1219,17 +1023,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1219
  def get_decoder(self):
1220
  return self.model
1221
 
1222
- # Ignore copy
1223
  @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
1224
- @replace_return_docstrings(
1225
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1226
- )
1227
  def forward(
1228
  self,
1229
  input_ids: torch.LongTensor = None,
1230
  attention_mask: Optional[torch.Tensor] = None,
1231
  position_ids: Optional[torch.LongTensor] = None,
1232
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1233
  inputs_embeds: Optional[torch.FloatTensor] = None,
1234
  labels: Optional[torch.LongTensor] = None,
1235
  use_cache: Optional[bool] = None,
@@ -1263,19 +1064,11 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1263
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1264
  "What is your favorite condiment?"
1265
  ```"""
1266
- output_attentions = (
1267
- output_attentions
1268
- if output_attentions is not None
1269
- else self.config.output_attentions
1270
- )
1271
  output_hidden_states = (
1272
- output_hidden_states
1273
- if output_hidden_states is not None
1274
- else self.config.output_hidden_states
1275
- )
1276
- return_dict = (
1277
- return_dict if return_dict is not None else self.config.use_return_dict
1278
  )
 
1279
 
1280
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1281
  outputs = self.model(
@@ -1326,118 +1119,68 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1326
  attention_mask=None,
1327
  inputs_embeds=None,
1328
  cache_position=None,
 
 
1329
  **kwargs,
1330
  ):
1331
- # With static cache, the `past_key_values` is None
1332
- # TODO joao: standardize interface for the different Cache classes and remove of this if
1333
- has_static_cache = False
1334
- if past_key_values is None:
1335
- past_key_values = getattr(
1336
- getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None
1337
- )
1338
- has_static_cache = past_key_values is not None
1339
-
1340
- past_length = 0
1341
  if past_key_values is not None:
1342
- if isinstance(past_key_values, Cache):
1343
- past_length = (
1344
- cache_position[0]
1345
- if cache_position is not None
1346
- else past_key_values.get_seq_length()
1347
- )
1348
- max_cache_length = (
1349
- torch.tensor(
1350
- past_key_values.get_max_length(), device=input_ids.device
1351
- )
1352
- if past_key_values.get_max_length() is not None
1353
- else None
1354
- )
1355
- cache_length = (
1356
- past_length
1357
- if max_cache_length is None
1358
- else torch.min(max_cache_length, past_length)
1359
- )
1360
- # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
1361
- else:
1362
- cache_length = past_length = past_key_values[0][0].shape[2]
1363
- max_cache_length = None
1364
-
1365
- # Keep only the unprocessed tokens:
1366
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1367
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1368
- # input)
1369
- if (
1370
- attention_mask is not None
1371
- and attention_mask.shape[1] > input_ids.shape[1]
1372
- ):
1373
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1374
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1375
- # input_ids based on the past_length.
1376
- elif past_length < input_ids.shape[1]:
1377
- input_ids = input_ids[:, past_length:]
1378
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1379
-
1380
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1381
- if (
1382
- max_cache_length is not None
1383
- and attention_mask is not None
1384
- and cache_length + input_ids.shape[1] > max_cache_length
1385
- ):
1386
- attention_mask = attention_mask[:, -max_cache_length:]
1387
 
1388
- position_ids = kwargs.get("position_ids", None)
1389
  if attention_mask is not None and position_ids is None:
1390
  # create position_ids on the fly for batch generation
1391
  position_ids = attention_mask.long().cumsum(-1) - 1
1392
  position_ids.masked_fill_(attention_mask == 0, 1)
1393
  if past_key_values:
1394
  position_ids = position_ids[:, -input_ids.shape[1] :]
 
 
1395
 
1396
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1397
- if inputs_embeds is not None and past_key_values is None:
1398
- model_inputs = {"inputs_embeds": inputs_embeds}
1399
  else:
1400
- # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1401
- # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1402
- # TODO: use `next_tokens` directly instead.
1403
- model_inputs = {"input_ids": input_ids.contiguous()}
1404
 
1405
- input_length = (
1406
- position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1407
- )
1408
- if cache_position is None:
1409
- cache_position = torch.arange(
1410
- past_length, past_length + input_length, device=input_ids.device
1411
- )
1412
- else:
1413
- cache_position = cache_position[-input_length:]
1414
 
1415
- if has_static_cache:
1416
- past_key_values = None
1417
 
1418
  model_inputs.update(
1419
  {
1420
  "position_ids": position_ids,
1421
  "cache_position": cache_position,
1422
  "past_key_values": past_key_values,
1423
- "use_cache": kwargs.get("use_cache"),
1424
  "attention_mask": attention_mask,
1425
  }
1426
  )
1427
  return model_inputs
1428
 
1429
- @staticmethod
1430
- def _reorder_cache(past_key_values, beam_idx):
1431
- reordered_past = ()
1432
- for layer_past in past_key_values:
1433
- reordered_past += (
1434
- tuple(
1435
- past_state.index_select(0, beam_idx.to(past_state.device))
1436
- for past_state in layer_past
1437
- ),
1438
- )
1439
- return reordered_past
1440
-
1441
 
1442
  @add_start_docstrings(
1443
  """
@@ -1454,7 +1197,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
1454
  """,
1455
  GEMMA_START_DOCSTRING,
1456
  )
1457
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMA,Llama->Gemma
1458
  class GemmaForSequenceClassification(GemmaPreTrainedModel):
1459
  def __init__(self, config):
1460
  super().__init__(config)
@@ -1477,7 +1219,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
1477
  input_ids: torch.LongTensor = None,
1478
  attention_mask: Optional[torch.Tensor] = None,
1479
  position_ids: Optional[torch.LongTensor] = None,
1480
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1481
  inputs_embeds: Optional[torch.FloatTensor] = None,
1482
  labels: Optional[torch.LongTensor] = None,
1483
  use_cache: Optional[bool] = None,
@@ -1491,9 +1233,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
1491
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1492
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1493
  """
1494
- return_dict = (
1495
- return_dict if return_dict is not None else self.config.use_return_dict
1496
- )
1497
 
1498
  transformer_outputs = self.model(
1499
  input_ids,
@@ -1515,25 +1255,19 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
1515
  batch_size = inputs_embeds.shape[0]
1516
 
1517
  if self.config.pad_token_id is None and batch_size != 1:
1518
- raise ValueError(
1519
- "Cannot handle batch sizes > 1 if no padding token is defined."
1520
- )
1521
  if self.config.pad_token_id is None:
1522
  sequence_lengths = -1
1523
  else:
1524
  if input_ids is not None:
1525
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1526
- sequence_lengths = (
1527
- torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1528
- )
1529
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1530
  sequence_lengths = sequence_lengths.to(logits.device)
1531
  else:
1532
  sequence_lengths = -1
1533
 
1534
- pooled_logits = logits[
1535
- torch.arange(batch_size, device=logits.device), sequence_lengths
1536
- ]
1537
 
1538
  loss = None
1539
  if labels is not None:
@@ -1541,9 +1275,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
1541
  if self.config.problem_type is None:
1542
  if self.num_labels == 1:
1543
  self.config.problem_type = "regression"
1544
- elif self.num_labels > 1 and (
1545
- labels.dtype == torch.long or labels.dtype == torch.int
1546
- ):
1547
  self.config.problem_type = "single_label_classification"
1548
  else:
1549
  self.config.problem_type = "multi_label_classification"
@@ -1556,9 +1288,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
1556
  loss = loss_fct(pooled_logits, labels)
1557
  elif self.config.problem_type == "single_label_classification":
1558
  loss_fct = CrossEntropyLoss()
1559
- loss = loss_fct(
1560
- pooled_logits.view(-1, self.num_labels), labels.view(-1)
1561
- )
1562
  elif self.config.problem_type == "multi_label_classification":
1563
  loss_fct = BCEWithLogitsLoss()
1564
  loss = loss_fct(pooled_logits, labels)
@@ -1573,3 +1303,87 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
1573
  hidden_states=transformer_outputs.hidden_states,
1574
  attentions=transformer_outputs.attentions,
1575
  )
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from <path_to_diff_file.py>.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the diff. If any change should be done, please apply the change to the
5
+ # diff.py file directly.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
  # coding=utf-8
8
  # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
9
  #
 
19
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
  # See the License for the specific language governing permissions and
21
  # limitations under the License.
 
 
22
  import math
 
23
  from typing import List, Optional, Tuple, Union
24
 
25
  import torch
 
26
  import torch.utils.checkpoint
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
 
30
+ from ...activations import ACT2FN
31
+ from ...cache_utils import Cache, DynamicCache, StaticCache
32
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
33
+ from ...modeling_flash_attention_utils import _flash_attention_forward
34
+ from ...modeling_outputs import (
 
 
35
  BaseModelOutputWithPast,
36
  CausalLMOutputWithPast,
37
  SequenceClassifierOutputWithPast,
38
+ TokenClassifierOutput,
39
  )
40
+ from ...modeling_utils import PreTrainedModel
41
+ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
42
+ from ...utils import (
43
  add_start_docstrings,
44
  add_start_docstrings_to_model_forward,
 
45
  is_flash_attn_greater_or_equal_2_10,
46
  logging,
47
  replace_return_docstrings,
48
  )
 
49
  from .configuration_gemma import GemmaConfig
50
 
51
 
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
56
+ def _prepare_4d_causal_attention_mask_with_cache_position(
57
+ attention_mask: torch.Tensor,
58
+ sequence_length: int,
59
+ target_length: int,
60
+ dtype: torch.dtype,
61
+ device: torch.device,
62
+ min_dtype: float,
63
+ cache_position: torch.Tensor,
64
+ batch_size: int,
65
+ ):
66
+ """
67
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
68
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
69
 
70
+ Args:
71
+ attention_mask (`torch.Tensor`):
72
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
73
+ sequence_length (`int`):
74
+ The sequence length being processed.
75
+ target_length (`int`):
76
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
77
+ dtype (`torch.dtype`):
78
+ The dtype to use for the 4D attention mask.
79
+ device (`torch.device`):
80
+ The device to place the 4D attention mask on.
81
+ min_dtype (`float`):
82
+ The minimum value representable with the dtype `dtype`.
83
+ cache_position (`torch.Tensor`):
84
+ Indices depicting the position of the input sequence tokens in the sequence.
85
+ batch_size (`torch.Tensor`):
86
+ Batch size.
87
+ """
88
+ if attention_mask is not None and attention_mask.dim() == 4:
89
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
90
+ causal_mask = attention_mask
91
+ else:
92
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
93
+ if sequence_length != 1:
94
+ causal_mask = torch.triu(causal_mask, diagonal=1)
95
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
96
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
97
+ if attention_mask is not None:
98
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
99
+ mask_length = attention_mask.shape[-1]
100
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
101
+ padding_mask = padding_mask == 0
102
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
103
+ padding_mask, min_dtype
104
+ )
105
 
106
+ return causal_mask
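# Illustrative usage sketch of the helper above (not part of the diff): a batch of two
# length-3 prompts, the second left-padded by one token, expanded against a target
# length of 8 as a static cache would require.
import torch

attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])  # second row is left-padded
dtype = torch.float16
mask = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=3,
    target_length=8,                      # e.g. the static cache length
    dtype=dtype,
    device=torch.device("cpu"),
    min_dtype=torch.finfo(dtype).min,
    cache_position=torch.arange(3),
    batch_size=2,
)
print(mask.shape)  # torch.Size([2, 1, 3, 8]); masked slots hold min_dtype, visible slots are 0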
107
 
108
 
109
+ class GemmaRMSNorm(nn.Module):
110
+ def __init__(self, dim: int, eps: float = 1e-6):
111
+ super().__init__()
112
+ self.eps = eps
113
+ self.weight = nn.Parameter(torch.zeros(dim))
114
 
115
+ def _norm(self, x):
116
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
117
 
118
+ def forward(self, x):
119
+ output = self._norm(x.float())
120
+ # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
121
+ # See https://github.com/huggingface/transformers/pull/29402
122
+ output = output * (1.0 + self.weight.float())
123
+ return output.type_as(x)
124
 
125
+ def extra_repr(self):
126
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
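# Illustrative sketch of the comment in forward() above (not from the checkpoint):
# in half precision, casting before vs. after the multiply is not bit-identical.
import torch

x = torch.randn(8, dtype=torch.float32)
w = torch.randn(8, dtype=torch.float32)
llama_style = x.to(torch.float16) * w.to(torch.float16)  # cast first, then multiply
gemma_style = (x * w).to(torch.float16)                   # multiply in float32, cast last
print(torch.equal(llama_style, gemma_style))              # often False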
127
 
128
 
129
  ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
 
136
  self.dim = dim
137
  self.max_position_embeddings = max_position_embeddings
138
  self.base = base
139
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
140
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
141
 
142
  @torch.no_grad()
143
  def forward(self, x, position_ids, seq_len=None):
144
  # x: [bs, num_attention_heads, seq_len, head_size]
145
+ self.inv_freq.to(x.device)
146
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
147
  position_ids_expanded = position_ids[:, None, :].float()
148
  # Force float32 since bfloat16 loses precision on long contexts
149
  # See https://github.com/huggingface/transformers/pull/29285
150
  device_type = x.device.type
151
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
152
  with torch.autocast(device_type=device_type, enabled=False):
153
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 
 
154
  emb = torch.cat((freqs, freqs), dim=-1)
155
  cos = emb.cos()
156
  sin = emb.sin()
157
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
158
 
159
 
160
+ class GemmaMLP(nn.Module):
161
+ def __init__(self, config):
162
+ super().__init__()
163
+ self.config = config
164
+ self.hidden_size = config.hidden_size
165
+ self.intermediate_size = config.intermediate_size
166
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
167
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
168
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
169
+ if config.hidden_activation is None:
170
+ logger.warning_once(
171
+ "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
172
+ "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
173
+ "`config.hidden_activation` if you want to override this behaviour.\n"
174
+ "See https://github.com/huggingface/transformers/pull/29402 for more details."
175
+ )
176
+ config.hidden_activation = "gelu_pytorch_tanh"
177
+ hidden_activation = config.hidden_activation
178
+ self.act_fn = ACT2FN[hidden_activation]
179
+
180
+ def forward(self, x):
181
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
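# Quick shape check of the gated-GeLU MLP above (illustrative sketch; SimpleNamespace
# stands in for a real GemmaConfig and GemmaMLP is assumed to be in scope).
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(hidden_size=16, intermediate_size=64,
                      hidden_activation="gelu_pytorch_tanh", hidden_act=None)
mlp = GemmaMLP(cfg)
print(mlp(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16]): hidden size is preserved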
182
+
183
+
184
+ class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
185
+ """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
186
+
187
+ def forward(self, x, position_ids):
188
+ # difference to the original RoPE: a scaling factor is applied to the position ids
189
+ position_ids = position_ids.float() / self.scaling_factor
190
+ cos, sin = super().forward(x, position_ids)
191
+ return cos, sin
192
+
193
+
194
+ class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
195
+ """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
196
+
197
+ def forward(self, x, position_ids):
198
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
199
+ seq_len = torch.max(position_ids) + 1
200
+ if seq_len > self.max_position_embeddings:
201
+ base = self.base * (
202
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
203
+ ) ** (self.dim / (self.dim - 2))
204
+ inv_freq = 1.0 / (
205
+ base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
206
+ )
207
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
208
+
209
+ cos, sin = super().forward(x, position_ids)
210
+ return cos, sin
211
+
212
+
213
  def rotate_half(x):
214
  """Rotates half the hidden dims of the input."""
215
  x1 = x[..., : x.shape[-1] // 2]
 
217
  return torch.cat((-x2, x1), dim=-1)
218
 
219
 
 
220
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
221
  """Applies Rotary Position Embedding to the query and key tensors.
222
 
 
244
  return q_embed, k_embed
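# Shape sketch tying the rotary pieces together (illustrative only; GemmaRotaryEmbedding
# and apply_rotary_pos_emb from above are assumed to be in scope).
import torch

bs, n_heads, n_kv_heads, seq_len, head_dim = 2, 4, 1, 5, 8
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_kv_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len)[None, :].expand(bs, -1)

rope = GemmaRotaryEmbedding(head_dim, max_position_embeddings=16, base=10000)
cos, sin = rope(q, position_ids)                     # both (bs, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # shapes are unchanged
print(q_rot.shape, k_rot.shape)                      # (2, 4, 5, 8) and (2, 1, 5, 8)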
245
 
246
 
247
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
248
  """
249
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 
252
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
253
  if n_rep == 1:
254
  return hidden_states
255
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
 
 
256
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
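# The docstring's torch.repeat_interleave equivalence can be checked directly
# (illustrative sketch, assuming repeat_kv above is in scope).
import torch

kv = torch.randn(2, 4, 5, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(kv, 3), kv.repeat_interleave(3, dim=1))
print(repeat_kv(kv, 3).shape)  # torch.Size([2, 12, 5, 16])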
257
 
258
 
259
  class GemmaAttention(nn.Module):
260
  """Multi-headed attention from 'Attention Is All You Need' paper"""
261
 
 
262
  def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
263
  super().__init__()
264
  self.config = config
 
279
  self.max_position_embeddings = config.max_position_embeddings
280
  self.rope_theta = config.rope_theta
281
  self.is_causal = True
282
+ self.scaling = 1 / math.sqrt(config.head_dim)
283
 
284
  if self.hidden_size % self.num_heads != 0:
285
  raise ValueError(
 
287
  f" and `num_heads`: {self.num_heads})."
288
  )
289
 
290
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
291
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
292
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
293
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
294
  self.rotary_emb = GemmaRotaryEmbedding(
295
  self.head_dim,
296
  max_position_embeddings=self.max_position_embeddings,
 
306
  output_attentions: bool = False,
307
  use_cache: bool = False,
308
  cache_position: Optional[torch.LongTensor] = None,
 
309
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
310
  bsz, q_len, _ = hidden_states.size()
311
 
 
313
  key_states = self.k_proj(hidden_states)
314
  value_states = self.v_proj(hidden_states)
315
 
316
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
317
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
318
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
319
+
320
+ cos, sin = self.rotary_emb(value_states, position_ids)
321
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
322
 
323
  if past_key_value is not None:
324
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
325
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
326
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
327
 
328
  key_states = repeat_kv(key_states, self.num_key_value_groups)
329
  value_states = repeat_kv(value_states, self.num_key_value_groups)
330
 
331
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
 
 
332
 
333
  if attention_mask is not None: # no matter the length, we just slice it
334
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
335
  attn_weights = attn_weights + causal_mask
336
 
337
  # upcast attention to fp32
338
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
339
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
340
  attn_output = torch.matmul(attn_weights, value_states)
341
 
342
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
356
  return attn_output, attn_weights, past_key_value
357
 
358
 
 
359
  class GemmaFlashAttention2(GemmaAttention):
360
  """
361
  Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
 
371
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
372
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
373
 
 
374
  def forward(
375
  self,
376
  hidden_states: torch.Tensor,
 
380
  output_attentions: bool = False,
381
  use_cache: bool = False,
382
  cache_position: Optional[torch.LongTensor] = None,
 
383
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
384
+ if isinstance(past_key_value, StaticCache):
385
+ raise ValueError(
386
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
387
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
388
+ )
389
+
390
  output_attentions = False
391
 
392
  bsz, q_len, _ = hidden_states.size()
 
398
  # Flash attention requires the input to have the shape
399
  # batch_size x seq_length x num_heads x head_dim
400
  # therefore we just need to keep the original shape
401
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
402
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
403
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
404
 
405
+ cos, sin = self.rotary_emb(value_states, position_ids)
406
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
407
 
408
  if past_key_value is not None:
409
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
410
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
411
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
412
 
413
  # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
414
  # to be able to avoid many of these transpose/reshape/view.
 
444
  key_states = key_states.to(target_dtype)
445
  value_states = value_states.to(target_dtype)
446
 
447
+ attn_output = _flash_attention_forward(
448
  query_states,
449
  key_states,
450
  value_states,
451
  attention_mask,
452
  q_len,
453
+ position_ids=position_ids,
454
  dropout=dropout_rate,
455
+ sliding_window=getattr(self, "sliding_window", None),
456
+ is_causal=self.is_causal,
457
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
458
  )
459
 
460
  attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
 
465
 
466
  return attn_output, attn_weights, past_key_value
467
 
468
 
469
  class GemmaSdpaAttention(GemmaAttention):
470
  """
471
  Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
 
473
  SDPA API.
474
  """
475
 
476
+ # Adapted from GemmaAttention.forward
477
  def forward(
478
  self,
479
  hidden_states: torch.Tensor,
 
483
  output_attentions: bool = False,
484
  use_cache: bool = False,
485
  cache_position: Optional[torch.LongTensor] = None,
486
+ **kwargs,
487
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
488
  if output_attentions:
489
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
 
507
  key_states = self.k_proj(hidden_states)
508
  value_states = self.v_proj(hidden_states)
509
 
510
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
511
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
512
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
513
 
514
+ cos, sin = self.rotary_emb(value_states, position_ids)
515
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
516
 
517
  if past_key_value is not None:
518
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
519
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
520
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
521
 
522
  key_states = repeat_kv(key_states, self.num_key_value_groups)
523
  value_states = repeat_kv(value_states, self.num_key_value_groups)
 
533
  key_states = key_states.contiguous()
534
  value_states = value_states.contiguous()
535
 
536
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
537
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
538
+ is_causal = True if causal_mask is None and q_len > 1 else False
539
+
540
  attn_output = torch.nn.functional.scaled_dot_product_attention(
541
  query_states,
542
  key_states,
543
  value_states,
544
  attn_mask=causal_mask,
545
  dropout_p=self.attention_dropout if self.training else 0.0,
546
+ is_causal=is_causal,
547
  )
548
 
549
  attn_output = attn_output.transpose(1, 2).contiguous()
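When no 4D mask was built and the query length is greater than 1, the SDPA path above asks the kernel to apply causality itself via `is_causal=True`; otherwise the prepared additive float mask is passed through `attn_mask`. A small sketch showing that the two call forms agree (toy tensors, illustration only):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)   # (batch, heads, q_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Prefill with no padding: let SDPA derive the causal structure itself.
out_a = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Equivalent explicit additive float mask (what the prepared 4D causal_mask looks like).
min_val = torch.finfo(q.dtype).min
causal_mask = torch.triu(torch.full((16, 16), min_val), diagonal=1)[None, None]
out_b = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)

print(torch.allclose(out_a, out_b, atol=1e-5))  # should print True (up to numerical tolerance)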
 
561
  }
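The brace above closes the dispatch table (`GEMMA_ATTENTION_CLASSES`, indexed below with `config._attn_implementation`), which selects between the eager, flash-attention and SDPA modules. The implementation is normally chosen at load time; a hedged usage sketch, with the checkpoint name only as an example:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    attn_implementation="sdpa",  # or "eager", or "flash_attention_2" if flash-attn is installed
)
print(model.config._attn_implementation)  # "sdpa" -> GemmaSdpaAttention is instantiated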
562
 
563
 
 
564
  class GemmaDecoderLayer(nn.Module):
565
  def __init__(self, config: GemmaConfig, layer_idx: int):
566
  super().__init__()
567
  self.hidden_size = config.hidden_size
568
 
569
+ self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
 
570
 
571
  self.mlp = GemmaMLP(config)
572
  self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
573
+ self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
574
 
575
  def forward(
576
  self,
577
  hidden_states: torch.Tensor,
578
  attention_mask: Optional[torch.Tensor] = None,
579
  position_ids: Optional[torch.LongTensor] = None,
580
+ past_key_value: Optional[Cache] = None,
581
  output_attentions: Optional[bool] = False,
582
  use_cache: Optional[bool] = False,
583
  cache_position: Optional[torch.LongTensor] = None,
584
  **kwargs,
585
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
586
  """
587
  Args:
588
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
596
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
597
  (see `past_key_values`).
598
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
599
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
600
+ Indices depicting the position of the input sequence tokens in the sequence
601
+ kwargs (`dict`, *optional*):
602
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
603
+ into the model
604
  """
 
 
 
 
 
605
  residual = hidden_states
606
 
607
  hidden_states = self.input_layernorm(hidden_states)
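The remainder of the layer (elided from this hunk) follows the usual pre-norm residual pattern: normalize, run attention, add the residual, then normalize, run the MLP, add the residual again. A toy stand-in showing just that skeleton (not this file's code; the real layer also threads the mask, position ids and cache through `self_attn`):

import torch
from torch import nn

class PreNormBlock(nn.Module):
    """Toy stand-in illustrating the residual structure of a decoder layer."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)           # GemmaRMSNorm in the real model
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)  # GemmaRMSNorm in the real model
        self.self_attn = nn.Identity()                             # placeholder for GemmaAttention
        self.mlp = nn.Identity()                                   # placeholder for GemmaMLP

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
        return residual + hidden_states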
 
661
  config_class = GemmaConfig
662
  base_model_prefix = "model"
663
  supports_gradient_checkpointing = True
 
664
  _no_split_modules = ["GemmaDecoderLayer"]
665
+ _skip_keys_device_placement = ["past_key_values"]
666
  _supports_flash_attn_2 = True
667
  _supports_sdpa = True
668
  _supports_cache_class = True
669
+ _supports_quantized_cache = True
670
+ _supports_static_cache = True
671
 
672
  def _init_weights(self, module):
673
  std = self.config.initializer_range
 
680
  if module.padding_idx is not None:
681
  module.weight.data[module.padding_idx].zero_()
682
 
683
 
684
+ _CONFIG_FOR_DOC = "GemmaConfig"
 
 
685
 
686
 
687
  GEMMA_INPUTS_DOCSTRING = r"""
 
762
  "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
763
  GEMMA_START_DOCSTRING,
764
  )
 
765
  class GemmaModel(GemmaPreTrainedModel):
766
  """
767
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
 
775
  self.padding_idx = config.pad_token_id
776
  self.vocab_size = config.vocab_size
777
 
778
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
 
779
  self.layers = nn.ModuleList(
780
+ [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 
 
 
781
  )
782
  self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
783
  self.gradient_checkpointing = False
 
792
  self.embed_tokens = value
793
 
794
  @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
 
795
  def forward(
796
  self,
797
  input_ids: torch.LongTensor = None,
798
  attention_mask: Optional[torch.Tensor] = None,
799
  position_ids: Optional[torch.LongTensor] = None,
800
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
801
  inputs_embeds: Optional[torch.FloatTensor] = None,
802
  use_cache: Optional[bool] = None,
803
  output_attentions: Optional[bool] = None,
 
805
  return_dict: Optional[bool] = None,
806
  cache_position: Optional[torch.LongTensor] = None,
807
  ) -> Union[Tuple, BaseModelOutputWithPast]:
808
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
809
  output_hidden_states = (
810
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
811
  )
812
  use_cache = use_cache if use_cache is not None else self.config.use_cache
813
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
814
 
815
  if (input_ids is None) ^ (inputs_embeds is not None):
816
  raise ValueError(
 
826
  if inputs_embeds is None:
827
  inputs_embeds = self.embed_tokens(input_ids)
828
 
829
+ return_legacy_cache = False # noqa: F841
830
+ if (
831
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
832
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
833
+ return_legacy_cache = True # noqa: F841
834
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
+ )
835
 
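The block above keeps backward compatibility with the legacy tuple-of-(key, value) cache format by wrapping it in a `DynamicCache`; the tuple form can be recovered at the end of the forward pass. A small round-trip sketch with toy tensors:

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each (batch, num_kv_heads, seq_len, head_dim).
legacy = tuple((torch.randn(1, 1, 4, 16), torch.randn(1, 1, 4, 16)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())           # 4 tokens already cached
legacy_again = cache.to_legacy_cache()  # back to the tuple format for callers that expect it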
836
  if cache_position is None:
837
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
838
  cache_position = torch.arange(
839
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
 
840
  )
841
 
842
  if position_ids is None:
843
  position_ids = cache_position.unsqueeze(0)
844
 
845
  causal_mask = self._update_causal_mask(
846
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
847
  )
848
 
849
  # embed positions
 
852
  # normalized
853
  # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
854
  # See https://github.com/huggingface/transformers/pull/29402
855
+ normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
 
 
856
  hidden_states = hidden_states * normalizer
857
+ if (
858
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
859
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
860
+ return_legacy_cache = True
861
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
862
+ logger.warning_once(
863
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
864
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
865
+ )
866
 
867
  # decoder layers
868
  all_hidden_states = () if output_hidden_states else None
 
909
  if output_hidden_states:
910
  all_hidden_states += (hidden_states,)
911
 
912
+ next_cache = next_decoder_cache if use_cache else None
913
+ if return_legacy_cache:
914
+ next_cache = next_cache.to_legacy_cache()
915
+
 
 
 
916
  if not return_dict:
917
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
918
  return BaseModelOutputWithPast(
919
  last_hidden_state=hidden_states,
920
  past_key_values=next_cache,
 
927
  attention_mask: torch.Tensor,
928
  input_tensor: torch.Tensor,
929
  cache_position: torch.Tensor,
930
+ past_key_values: Cache,
931
+ output_attentions: bool,
932
  ):
933
  # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
934
  # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
 
940
  return attention_mask
941
  return None
942
 
943
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
944
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
945
+ # to infer the attention mask.
946
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
947
+ using_static_cache = isinstance(past_key_values, StaticCache)
948
+
949
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
950
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
951
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
952
  attention_mask,
953
  inputs_embeds=input_tensor,
954
  past_key_values_length=past_seen_tokens,
955
+ is_training=self.training,
956
  ):
957
  return None
958
 
959
  dtype, device = input_tensor.dtype, input_tensor.device
960
  min_dtype = torch.finfo(dtype).min
961
  sequence_length = input_tensor.shape[1]
962
+ if using_static_cache:
963
+ target_length = past_key_values.get_max_length()
964
+ else:
 
 
965
  target_length = (
966
  attention_mask.shape[-1]
967
  if isinstance(attention_mask, torch.Tensor)
968
  else past_seen_tokens + sequence_length + 1
969
  )
970
 
971
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
972
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
973
+ attention_mask,
974
+ sequence_length=sequence_length,
975
+ target_length=target_length,
976
  dtype=dtype,
977
  device=device,
978
+ min_dtype=min_dtype,
979
+ cache_position=cache_position,
980
+ batch_size=input_tensor.shape[0],
981
  )
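For intuition, the helper called above turns a 2D padding mask into the additive 4D float mask consumed by the attention modules: key positions a query may not attend to get the most negative representable value, everything else 0. A simplified stand-in of the same idea (not the helper's exact implementation):

import torch

def simple_4d_causal_mask(attention_mask_2d, sequence_length, target_length, dtype, cache_position):
    # attention_mask_2d: (batch, target_length) with 1 = real token, 0 = padding
    min_dtype = torch.finfo(dtype).min
    batch_size = attention_mask_2d.shape[0]
    # Causal part: the query at cache_position[i] may only attend to key positions <= cache_position[i].
    causal = torch.arange(target_length) > cache_position.reshape(-1, 1)  # (sequence_length, target_length)
    mask = torch.zeros(sequence_length, target_length, dtype=dtype).masked_fill(causal, min_dtype)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    # Padding part: fully mask padded key positions.
    return mask.masked_fill(attention_mask_2d[:, None, None, :] == 0, min_dtype)

mask2d = torch.tensor([[1, 1, 1, 0]])   # last key position is padding
mask4d = simple_4d_causal_mask(mask2d, 4, 4, torch.float32, torch.arange(4))
print(mask4d.shape)  # torch.Size([1, 1, 4, 4])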
982
  if (
983
  self.config._attn_implementation == "sdpa"
984
  and attention_mask is not None
985
  and attention_mask.device.type == "cuda"
986
+ and not output_attentions
987
  ):
988
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
989
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
990
  # Details: https://github.com/pytorch/pytorch/issues/110213
991
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
 
992
 
993
  return causal_mask
994
 
995
 
 
996
  class GemmaForCausalLM(GemmaPreTrainedModel):
997
  _tied_weights_keys = ["lm_head.weight"]
998
 
 
1023
  def get_decoder(self):
1024
  return self.model
1025
 
 
1026
  @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
1027
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
1028
  def forward(
1029
  self,
1030
  input_ids: torch.LongTensor = None,
1031
  attention_mask: Optional[torch.Tensor] = None,
1032
  position_ids: Optional[torch.LongTensor] = None,
1033
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1034
  inputs_embeds: Optional[torch.FloatTensor] = None,
1035
  labels: Optional[torch.LongTensor] = None,
1036
  use_cache: Optional[bool] = None,
 
1064
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1065
  "What is your favorite condiment?"
1066
  ```"""
1067
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
1068
  output_hidden_states = (
1069
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
1070
  )
1071
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1072
 
1073
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1074
  outputs = self.model(
 
1119
  attention_mask=None,
1120
  inputs_embeds=None,
1121
  cache_position=None,
1122
+ position_ids=None,
1123
+ use_cache=True,
1124
  **kwargs,
1125
  ):
1126
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1127
+ # Exception 1: when passing inputs_embeds, input_ids may be missing entries
1128
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1129
  if past_key_values is not None:
1130
+ if inputs_embeds is not None: # Exception 1
1131
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1132
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1133
+ input_ids = input_ids[:, cache_position]
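During generation only the tokens that have not been processed yet are fed to the model, and `cache_position` indexes exactly those. A toy illustration of the default slicing case above:

import torch

input_ids = torch.tensor([[5, 6, 7, 8, 9]])   # full prompt + tokens generated so far
cache_position = torch.tensor([4])            # 4 tokens are already in the KV cache

# Default case: keep only the yet-unprocessed token(s).
sliced = input_ids[:, cache_position]
print(sliced)  # tensor([[9]])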
1134
 
 
1135
  if attention_mask is not None and position_ids is None:
1136
  # create position_ids on the fly for batch generation
1137
  position_ids = attention_mask.long().cumsum(-1) - 1
1138
  position_ids.masked_fill_(attention_mask == 0, 1)
1139
  if past_key_values:
1140
  position_ids = position_ids[:, -input_ids.shape[1] :]
1141
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
1142
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
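The cumulative-sum trick above derives position ids from a (possibly left-padded) attention mask, pinning padded slots to a harmless value. A small illustration:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])   # two left-padding tokens
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[1, 1, 0, 1, 2]]) -> real tokens get positions 0, 1, 2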
1143
 
1144
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1145
+ if inputs_embeds is not None and cache_position[0] == 0:
1146
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1147
  else:
1148
+ # The clone here is for the same reason as for `position_ids`.
1149
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
 
1150
 
1151
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1152
+ if model_inputs["inputs_embeds"] is not None:
1153
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1154
+ device = model_inputs["inputs_embeds"].device
1155
+ else:
1156
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1157
+ device = model_inputs["input_ids"].device
 
 
1158
 
1159
+ dtype = self.lm_head.weight.dtype
1160
+ min_dtype = torch.finfo(dtype).min
1161
+
1162
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1163
+ attention_mask,
1164
+ sequence_length=sequence_length,
1165
+ target_length=past_key_values.get_max_length(),
1166
+ dtype=dtype,
1167
+ device=device,
1168
+ min_dtype=min_dtype,
1169
+ cache_position=cache_position,
1170
+ batch_size=batch_size,
1171
+ )
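When a `StaticCache` is in use, the branch above pre-builds the full 4D mask so shapes stay constant for `torch.compile`. The static cache itself is usually requested through `generate`; a hedged usage sketch, with the checkpoint name only as an example:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")

inputs = tokenizer("The capital of France is", return_tensors="pt")
# cache_implementation="static" asks generate to allocate a fixed-size StaticCache.
output_ids = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))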
1172
 
1173
  model_inputs.update(
1174
  {
1175
  "position_ids": position_ids,
1176
  "cache_position": cache_position,
1177
  "past_key_values": past_key_values,
1178
+ "use_cache": use_cache,
1179
  "attention_mask": attention_mask,
1180
  }
1181
  )
1182
  return model_inputs
1183
 
1184
 
1185
  @add_start_docstrings(
1186
  """
 
1197
  """,
1198
  GEMMA_START_DOCSTRING,
1199
  )
 
1200
  class GemmaForSequenceClassification(GemmaPreTrainedModel):
1201
  def __init__(self, config):
1202
  super().__init__(config)
 
1219
  input_ids: torch.LongTensor = None,
1220
  attention_mask: Optional[torch.Tensor] = None,
1221
  position_ids: Optional[torch.LongTensor] = None,
1222
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1223
  inputs_embeds: Optional[torch.FloatTensor] = None,
1224
  labels: Optional[torch.LongTensor] = None,
1225
  use_cache: Optional[bool] = None,
 
1233
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1234
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1235
  """
1236
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1237
 
1238
  transformer_outputs = self.model(
1239
  input_ids,
 
1255
  batch_size = inputs_embeds.shape[0]
1256
 
1257
  if self.config.pad_token_id is None and batch_size != 1:
1258
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
 
 
1259
  if self.config.pad_token_id is None:
1260
  sequence_lengths = -1
1261
  else:
1262
  if input_ids is not None:
1263
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1264
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
 
 
1265
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1266
  sequence_lengths = sequence_lengths.to(logits.device)
1267
  else:
1268
  sequence_lengths = -1
1269
 
1270
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
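Sequence classification pools the logits of the last non-padding token of each row; with right padding that index is the position just before the first pad token, and rows without padding wrap around to the final position. A toy illustration of the indexing above:

import torch

pad_token_id = 0
input_ids = torch.tensor([[11, 12, 13, 0, 0],      # right-padded row
                          [21, 22, 23, 24, 25]])   # full row, no padding
logits = torch.randn(2, 5, 3)                      # (batch, seq_len, num_labels)

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]  # -1 wraps to the last index for unpadded rows
pooled_logits = logits[torch.arange(2), sequence_lengths]

print(sequence_lengths)      # tensor([2, 4])
print(pooled_logits.shape)   # torch.Size([2, 3])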
 
 
1271
 
1272
  loss = None
1273
  if labels is not None:
 
1275
  if self.config.problem_type is None:
1276
  if self.num_labels == 1:
1277
  self.config.problem_type = "regression"
1278
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
 
 
1279
  self.config.problem_type = "single_label_classification"
1280
  else:
1281
  self.config.problem_type = "multi_label_classification"
 
1288
  loss = loss_fct(pooled_logits, labels)
1289
  elif self.config.problem_type == "single_label_classification":
1290
  loss_fct = CrossEntropyLoss()
1291
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1292
  elif self.config.problem_type == "multi_label_classification":
1293
  loss_fct = BCEWithLogitsLoss()
1294
  loss = loss_fct(pooled_logits, labels)
 
1303
  hidden_states=transformer_outputs.hidden_states,
1304
  attentions=transformer_outputs.attentions,
1305
  )
1306
+
1307
+
1308
+ @add_start_docstrings(
1309
+ """
1310
+ The Gemma Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1311
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1312
+ """,
1313
+ GEMMA_START_DOCSTRING,
1314
+ )
1315
+ class GemmaForTokenClassification(GemmaPreTrainedModel):
1316
+ def __init__(self, config):
1317
+ super().__init__(config)
1318
+ self.num_labels = config.num_labels
1319
+ self.model = GemmaModel(config)
1320
+ if getattr(config, "classifier_dropout", None) is not None:
1321
+ classifier_dropout = config.classifier_dropout
1322
+ elif getattr(config, "hidden_dropout", None) is not None:
1323
+ classifier_dropout = config.hidden_dropout
1324
+ else:
1325
+ classifier_dropout = 0.1
1326
+ self.dropout = nn.Dropout(classifier_dropout)
1327
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1328
+
1329
+ # Initialize weights and apply final processing
1330
+ self.post_init()
1331
+
1332
+ def get_input_embeddings(self):
1333
+ return self.model.embed_tokens
1334
+
1335
+ def set_input_embeddings(self, value):
1336
+ self.model.embed_tokens = value
1337
+
1338
+ @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
1339
+ def forward(
1340
+ self,
1341
+ input_ids: Optional[torch.LongTensor] = None,
1342
+ attention_mask: Optional[torch.Tensor] = None,
1343
+ position_ids: Optional[torch.LongTensor] = None,
1344
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1345
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1346
+ labels: Optional[torch.LongTensor] = None,
1347
+ use_cache: Optional[bool] = None,
1348
+ output_attentions: Optional[bool] = None,
1349
+ output_hidden_states: Optional[bool] = None,
1350
+ return_dict: Optional[bool] = None,
1351
+ ) -> Union[Tuple, TokenClassifierOutput]:
1352
+ r"""
1353
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1354
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1355
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1356
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1357
+ """
1358
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1359
+
1360
+ outputs = self.model(
1361
+ input_ids,
1362
+ attention_mask=attention_mask,
1363
+ position_ids=position_ids,
1364
+ past_key_values=past_key_values,
1365
+ inputs_embeds=inputs_embeds,
1366
+ use_cache=use_cache,
1367
+ output_attentions=output_attentions,
1368
+ output_hidden_states=output_hidden_states,
1369
+ return_dict=return_dict,
1370
+ )
1371
+ sequence_output = outputs[0]
1372
+ sequence_output = self.dropout(sequence_output)
1373
+ logits = self.score(sequence_output)
1374
+
1375
+ loss = None
1376
+ if labels is not None:
1377
+ loss_fct = CrossEntropyLoss()
1378
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1379
+
1380
+ if not return_dict:
1381
+ output = (logits,) + outputs[2:]
1382
+ return ((loss,) + output) if loss is not None else output
1383
+
1384
+ return TokenClassifierOutput(
1385
+ loss=loss,
1386
+ logits=logits,
1387
+ hidden_states=outputs.hidden_states,
1388
+ attentions=outputs.attentions,
1389
+ )