shivanandmn committed (verified) · Commit 083fb87 · 1 Parent(s): 7f476c2

Model save

Files changed (3):
  1. README.md (+79)
  2. generation_config.json (+7)
  3. modeling_rotating_head_gpt2.py (+1132)
README.md ADDED
@@ -0,0 +1,79 @@
---
library_name: transformers
tags:
- generated_from_trainer
metrics:
- accuracy
- bleu
model-index:
- name: rotating-head-gp-norm-gpt2-medium-wikitext
  results: []
---

<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->

# rotating-head-gp-norm-gpt2-medium-wikitext

This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
It achieves the following results on the evaluation set:
- Loss: 3.2113
- Accuracy: 0.4180
- Perplexity: 24.8108
- Bleu: 0.1307
+ ## Model description
26
+
27
+ More information needed
28
+
29
+ ## Intended uses & limitations
30
+
31
+ More information needed
32
+
33
+ ## Training and evaluation data
34
+
35
+ More information needed
36
+
37
+ ## Training procedure
38
+
39
+ ### Training hyperparameters
40
+
41
+ The following hyperparameters were used during training:
42
+ - learning_rate: 0.0001
43
+ - train_batch_size: 64
44
+ - eval_batch_size: 64
45
+ - seed: 42
46
+ - optimizer: Use OptimizerNames.ADAMW_TORCH with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
47
+ - lr_scheduler_type: linear
48
+ - lr_scheduler_warmup_ratio: 0.1
49
+ - num_epochs: 5
50
+
51
+ ### Training results
52
+
53
+ | Training Loss | Epoch | Step | Validation Loss | Accuracy | Perplexity | Bleu |
54
+ |:-------------:|:------:|:----:|:---------------:|:--------:|:----------:|:------:|
55
+ | 5.9057 | 0.2806 | 500 | 5.7484 | 0.2234 | 313.6789 | 0.0477 |
56
+ | 4.8613 | 0.5612 | 1000 | 4.7455 | 0.2807 | 115.0632 | 0.0711 |
57
+ | 4.2976 | 0.8418 | 1500 | 4.2220 | 0.3187 | 68.1694 | 0.0837 |
58
+ | 3.9568 | 1.1223 | 2000 | 3.9271 | 0.3461 | 50.7582 | 0.0934 |
59
+ | 3.7919 | 1.4029 | 2500 | 3.7617 | 0.3626 | 43.0211 | 0.0942 |
60
+ | 3.692 | 1.6835 | 3000 | 3.6573 | 0.3725 | 38.7561 | 0.1052 |
61
+ | 3.5939 | 1.9641 | 3500 | 3.5628 | 0.3818 | 35.2616 | 0.1094 |
62
+ | 3.483 | 2.2447 | 4000 | 3.4932 | 0.3879 | 32.8924 | 0.1140 |
63
+ | 3.4251 | 2.5253 | 4500 | 3.4391 | 0.3933 | 31.1583 | 0.1204 |
64
+ | 3.3876 | 2.8058 | 5000 | 3.3855 | 0.3991 | 29.5323 | 0.1227 |
65
+ | 3.2719 | 3.0864 | 5500 | 3.3499 | 0.4020 | 28.5004 | 0.1246 |
66
+ | 3.2612 | 3.3670 | 6000 | 3.3160 | 0.4062 | 27.5488 | 0.1283 |
67
+ | 3.2373 | 3.6476 | 6500 | 3.2848 | 0.4095 | 26.7034 | 0.1288 |
68
+ | 3.2086 | 3.9282 | 7000 | 3.2598 | 0.4118 | 26.0453 | 0.1297 |
69
+ | 3.1402 | 4.2088 | 7500 | 3.2398 | 0.4146 | 25.5281 | 0.1344 |
70
+ | 3.1002 | 4.4893 | 8000 | 3.2246 | 0.4162 | 25.1447 | 0.1317 |
71
+ | 3.1099 | 4.7699 | 8500 | 3.2113 | 0.4180 | 24.8108 | 0.1307 |
72
+
73
+
74
+ ### Framework versions
75
+
76
+ - Transformers 4.49.0
77
+ - Pytorch 2.6.0+cu124
78
+ - Datasets 3.3.2
79
+ - Tokenizers 0.21.0
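Because the checkpoint ships its own modeling code (`modeling_rotating_head_gpt2.py`), loading it goes through `trust_remote_code`. A minimal usage sketch, assuming the repo id below (inferred from the model name), that the repo's `config.json` maps the Auto classes to the custom file, and that the file's imports resolve; none of this is stated in the card itself:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "shivanandmn/rotating-head-gp-norm-gpt2-medium-wikitext"  # assumed repo id

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumes the stock GPT-2 tokenizer (bos/eos id 50256)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Wikipedia articles describe", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```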
generation_config.json ADDED
@@ -0,0 +1,7 @@
{
  "_from_model_config": true,
  "bos_token_id": 50256,
  "eos_token_id": 50256,
  "transformers_version": "4.49.0",
  "use_cache": false
}
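These defaults are what `model.generate()` picks up automatically; `use_cache` is disabled here (mirroring the training setup) but can be overridden per call. A small sketch of inspecting the file, again assuming the repo id:

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("shivanandmn/rotating-head-gp-norm-gpt2-medium-wikitext")  # assumed repo id
print(gen_cfg.bos_token_id, gen_cfg.eos_token_id, gen_cfg.use_cache)  # 50256 50256 False

# Individual fields can be overridden at call time, e.g. re-enable the KV cache for faster decoding:
# model.generate(**inputs, max_new_tokens=40, use_cache=True)
```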
modeling_rotating_head_gpt2.py ADDED
@@ -0,0 +1,1132 @@
1
+
2
+ """PyTorch OpenAI GPT-2 model, code copied from Huggingface"""
3
+
4
+ import math
5
+ import os
6
+ import warnings
7
+ from dataclasses import dataclass
8
+ from typing import Callable, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.generation import GenerationMixin
17
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
18
+ from transformers.modeling_outputs import (
19
+ BaseModelOutputWithPastAndCrossAttentions,
20
+ CausalLMOutputWithCrossAttentions,
21
+ QuestionAnsweringModelOutput,
22
+ SequenceClassifierOutputWithPast,
23
+ TokenClassifierOutput,
24
+ )
25
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary
26
+ from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
27
+ from transformers.utils import (
28
+ ModelOutput,
29
+ add_code_sample_docstrings,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
36
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
37
+ from src.models.modeling_gpt2 import GPT2PreTrainedModel, GPT2Block
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ class RotatingHeadGPT2Config(GPT2Config):
42
+ model_type = "rotating-head-gpt2"
43
+ architectures = ["RotatingHeadGPT2LMHeadModel"]
44
+
45
+ class RotatingHeadGPT2PretrainedModel(GPT2PreTrainedModel):
46
+ config_class = RotatingHeadGPT2Config
47
+
48
+
49
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
50
+ """Load tf checkpoints in a pytorch model"""
51
+ try:
52
+ import re
53
+
54
+ import tensorflow as tf
55
+ except ImportError:
56
+ logger.error(
57
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
58
+ "https://www.tensorflow.org/install/ for installation instructions."
59
+ )
60
+ raise
61
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
62
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
63
+ # Load weights from TF model
64
+ init_vars = tf.train.list_variables(tf_path)
65
+ names = []
66
+ arrays = []
67
+ for name, shape in init_vars:
68
+ logger.info(f"Loading TF weight {name} with shape {shape}")
69
+ array = tf.train.load_variable(tf_path, name)
70
+ names.append(name)
71
+ arrays.append(array.squeeze())
72
+
73
+ for name, array in zip(names, arrays):
74
+ name = name[6:] # skip "model/"
75
+ name = name.split("/")
76
+ pointer = model
77
+ for m_name in name:
78
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
79
+ scope_names = re.split(r"(\d+)", m_name)
80
+ else:
81
+ scope_names = [m_name]
82
+ if scope_names[0] == "w" or scope_names[0] == "g":
83
+ pointer = getattr(pointer, "weight")
84
+ elif scope_names[0] == "b":
85
+ pointer = getattr(pointer, "bias")
86
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
87
+ pointer = getattr(pointer, scope_names[0])
88
+ pointer = getattr(pointer, "weight")
89
+ else:
90
+ pointer = getattr(pointer, scope_names[0])
91
+ if len(scope_names) >= 2:
92
+ num = int(scope_names[1])
93
+ pointer = pointer[num]
94
+ try:
95
+ if pointer.shape != array.shape:
96
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
97
+ except ValueError as e:
98
+ e.args += (pointer.shape, array.shape)
99
+ raise
100
+ logger.info(f"Initialize PyTorch weight {name}")
101
+ pointer.data = torch.from_numpy(array)
102
+ return model
103
+
104
+
105
+ def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
106
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
107
+
108
+ if module.scale_attn_weights:
109
+ attn_weights = attn_weights / torch.full(
110
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
111
+ )
112
+
113
+ # Layer-wise attention scaling
114
+ if module.scale_attn_by_inverse_layer_idx:
115
+ attn_weights = attn_weights / float(module.layer_idx + 1)
116
+
117
+ if not module.is_cross_attention:
118
+ # if only "normal" attention layer implements causal mask
119
+ query_length, key_length = query.size(-2), key.size(-2)
120
+ causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
121
+ mask_value = torch.finfo(attn_weights.dtype).min
122
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
123
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
124
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
125
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
126
+
127
+ if attention_mask is not None:
128
+ # Apply the attention mask
129
+ attn_weights = attn_weights + attention_mask
130
+
131
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
132
+
133
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
134
+ attn_weights = attn_weights.type(value.dtype)
135
+ attn_weights = module.attn_dropout(attn_weights)
136
+
137
+ # Mask heads if we want to
138
+ if head_mask is not None:
139
+ attn_weights = attn_weights * head_mask
140
+
141
+ attn_output = torch.matmul(attn_weights, value)
142
+ attn_output = attn_output.transpose(1, 2)
143
+
144
+ return attn_output, attn_weights
145
+
146
+
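For intuition about the masking step in `eager_attention_forward` above: causal positions that should not be attended to are filled with the dtype's minimum value before the softmax, which drives their probabilities to (numerically) zero. A toy illustration with dummy scores, not taken from the model:

```python
import torch

scores = torch.randn(1, 1, 3, 3)  # [batch, heads, query, key] toy attention scores
causal = torch.tril(torch.ones(3, 3, dtype=torch.bool))
mask_value = torch.full([], torch.finfo(scores.dtype).min)
masked = torch.where(causal, scores, mask_value)
probs = torch.softmax(masked, dim=-1)
print(probs[0, 0])  # upper-triangular entries are ~0; each row still sums to 1
```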
147
+ class HeadSpecificLRRoPE(nn.Module):
148
+ def __init__(self, num_heads, head_dim):
149
+ super().__init__()
150
+ self.num_heads = num_heads
151
+ self.head_dim = head_dim
152
+
153
+ # Initialize head-specific frequencies (learnable)
154
+ self.frequencies = nn.Parameter(torch.randn(num_heads, head_dim // 2))
155
+ self.layer_norm = nn.LayerNorm(head_dim // 2)
156
+
157
+ def forward(self, Q, K):
158
+ bs, heads, seq, embed = Q.size()
159
+ # Q = torch.einsum('bse,hed->bhse', X, W_Q) # [batch, heads, seq, embed]
160
+ # K = torch.einsum('bse,hed->bhse', X, W_K)
161
+
162
+ positions = torch.arange(seq, device=Q.device).unsqueeze(1) # [seq_length, 1]
163
+
164
+ cos_theta = torch.cos(positions * self.layer_norm(self.frequencies.unsqueeze(1)))
165
+ sin_theta = torch.sin(positions * self.layer_norm(self.frequencies.unsqueeze(1)))
166
+
167
+ Q_even, Q_odd = Q[..., ::2], Q[..., 1::2]
168
+ K_even, K_odd = K[..., ::2], K[..., 1::2]
169
+
170
+ Q_rotated = torch.stack([Q_even * cos_theta - Q_odd * sin_theta,
171
+ Q_even * sin_theta + Q_odd * cos_theta], dim=-1).reshape_as(Q)
172
+ K_rotated = torch.stack([K_even * cos_theta - K_odd * sin_theta,
173
+ K_even * sin_theta + K_odd * cos_theta], dim=-1).reshape_as(K)
174
+
175
+ return Q_rotated, K_rotated
176
+
177
+ class HeadSpecificGPRoPE(nn.Module):
178
+ def __init__(self, num_heads, head_dim, base_frequency=10000):
179
+ super().__init__()
180
+ self.num_heads = num_heads
181
+ self.head_dim = head_dim
182
+
183
+ # Geometric frequency progression (fixed)
184
+ frequency_base = base_frequency ** (-torch.arange(0, head_dim, 2).float() / head_dim)
185
+ scales = torch.logspace(0, -1, steps=num_heads, base=10.0).unsqueeze(1) # [num_heads, 1]
186
+ self.frequencies = (scales @ frequency_base.unsqueeze(0)) # [num_heads, dim//2]
187
+ self.layer_norm = nn.LayerNorm(head_dim)
188
+
189
+ def forward(self, Q, K):
190
+ bs, heads, seq, embed = Q.size()
191
+ # Q = torch.einsum('bse,hed->bhse', X, W_Q) # [batch, heads, seq, embed]
192
+ # K = torch.einsum('bse,hed->bhse', X, W_K)
193
+
194
+ positions = torch.arange(seq, device=Q.device).unsqueeze(1).unsqueeze(0) # [1, seq_length, 1]
195
+
196
+ cos_theta = torch.cos(positions * self.layer_norm(self.frequencies.unsqueeze(1)))
197
+ sin_theta = torch.sin(positions * self.layer_norm(self.frequencies.unsqueeze(1)))
198
+
199
+ Q_even, Q_odd = Q[..., ::2], Q[..., 1::2]
200
+ K_even, K_odd = K[..., ::2], K[..., 1::2]
201
+
202
+ Q_rotated = torch.stack([Q_even * cos_theta - Q_odd * sin_theta,
203
+ Q_even * sin_theta + Q_odd * cos_theta], dim=-1).reshape_as(Q)
204
+ K_rotated = torch.stack([K_even * cos_theta - K_odd * sin_theta,
205
+ K_even * sin_theta + K_odd * cos_theta], dim=-1).reshape_as(K)
206
+
207
+ return Q_rotated, K_rotated
208
+
209
+
210
+
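A standalone sketch of the idea behind `HeadSpecificGPRoPE`: each head gets its own geometric scale on top of the usual RoPE frequency ladder, so different heads rotate at different rates. This simplified version omits the `LayerNorm` over the frequencies used in the class above and is only meant to show the shapes and the norm-preserving rotation (all sizes here are illustrative):

```python
import torch

num_heads, head_dim, seq_len, base = 4, 8, 6, 10000

# Standard RoPE frequency ladder, shape [head_dim // 2]
freq = base ** (-torch.arange(0, head_dim, 2).float() / head_dim)
# Per-head geometric scaling -> frequencies of shape [num_heads, head_dim // 2]
scales = torch.logspace(0, -1, steps=num_heads, base=10.0).unsqueeze(1)
frequencies = scales @ freq.unsqueeze(0)

positions = torch.arange(seq_len).view(1, seq_len, 1)                 # [1, seq, 1]
theta = positions * frequencies.unsqueeze(1)                          # [heads, seq, head_dim // 2]
cos_t, sin_t = torch.cos(theta), torch.sin(theta)

q = torch.randn(2, num_heads, seq_len, head_dim)                      # [batch, heads, seq, head_dim]
q_even, q_odd = q[..., ::2], q[..., 1::2]
q_rot = torch.stack([q_even * cos_t - q_odd * sin_t,
                     q_even * sin_t + q_odd * cos_t], dim=-1).reshape_as(q)

print(q_rot.shape)                                                    # torch.Size([2, 4, 6, 8])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # rotation preserves per-token norm
```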
211
+ class GPT2Attention(nn.Module):
212
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
213
+ super().__init__()
214
+ self.config = config
215
+ max_positions = config.max_position_embeddings
216
+ self.register_buffer(
217
+ "bias",
218
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
219
+ 1, 1, max_positions, max_positions
220
+ ),
221
+ persistent=False,
222
+ )
223
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
224
+
225
+ self.embed_dim = config.hidden_size
226
+ self.num_heads = config.num_attention_heads
227
+ self.head_dim = self.embed_dim // self.num_heads
228
+ self.split_size = self.embed_dim
229
+ if self.head_dim * self.num_heads != self.embed_dim:
230
+ raise ValueError(
231
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
232
+ f" {self.num_heads})."
233
+ )
234
+
235
+ self.scale_attn_weights = config.scale_attn_weights
236
+ self.is_cross_attention = is_cross_attention
237
+
238
+ # Layer-wise attention scaling, reordering, and upcasting
239
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
240
+ self.layer_idx = layer_idx
241
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
242
+
243
+ if self.is_cross_attention:
244
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
245
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
246
+ else:
247
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
248
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
249
+
250
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
251
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
252
+ self.is_causal = True
253
+
254
+ self.pruned_heads = set()
255
+
256
+ def prune_heads(self, heads):
257
+ if len(heads) == 0:
258
+ return
259
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
260
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
261
+
262
+ # Prune conv1d layers
263
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
264
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
265
+
266
+ # Update hyper params
267
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
268
+ self.num_heads = self.num_heads - len(heads)
269
+ self.pruned_heads = self.pruned_heads.union(heads)
270
+
271
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
272
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
273
+ bsz, num_heads, q_seq_len, dk = query.size()
274
+ _, _, k_seq_len, _ = key.size()
275
+
276
+ # Preallocate attn_weights for `baddbmm`
277
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
278
+
279
+ # Compute Scale Factor
280
+ scale_factor = 1.0
281
+ if self.scale_attn_weights:
282
+ scale_factor /= float(value.size(-1)) ** 0.5
283
+
284
+ if self.scale_attn_by_inverse_layer_idx:
285
+ scale_factor /= float(self.layer_idx + 1)
286
+
287
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
288
+ with torch.amp.autocast(query.device.type, enabled=False):
289
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
290
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
291
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
292
+
293
+ if not self.is_cross_attention:
294
+ # if only "normal" attention layer implements causal mask
295
+ query_length, key_length = query.size(-2), key.size(-2)
296
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
297
+ mask_value = torch.finfo(attn_weights.dtype).min
298
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
299
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
300
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
301
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
302
+
303
+ if attention_mask is not None:
304
+ # Apply the attention mask
305
+ attn_weights = attn_weights + attention_mask
306
+
307
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
308
+
309
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
310
+ if attn_weights.dtype != torch.float32:
311
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
312
+ attn_weights = attn_weights.type(value.dtype)
313
+ attn_weights = self.attn_dropout(attn_weights)
314
+
315
+ # Mask heads if we want to
316
+ if head_mask is not None:
317
+ attn_weights = attn_weights * head_mask
318
+
319
+ attn_output = torch.matmul(attn_weights, value)
320
+ attn_output = attn_output.transpose(1, 2)
321
+
322
+ return attn_output, attn_weights
323
+
324
+ def forward(
325
+ self,
326
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
327
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
328
+ attention_mask: Optional[torch.FloatTensor] = None,
329
+ head_mask: Optional[torch.FloatTensor] = None,
330
+ encoder_hidden_states: Optional[torch.Tensor] = None,
331
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
332
+ use_cache: Optional[bool] = False,
333
+ output_attentions: Optional[bool] = False,
334
+ **kwargs,
335
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
336
+ if encoder_hidden_states is not None:
337
+ if not hasattr(self, "q_attn"):
338
+ raise ValueError(
339
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
340
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
341
+ )
342
+
343
+ query_states = self.q_attn(hidden_states)
344
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
345
+ attention_mask = encoder_attention_mask
346
+ else:
347
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
348
+
349
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
350
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
351
+
352
+ query_states = query_states.view(shape_q).transpose(1, 2)
353
+ key_states = key_states.view(shape_kv).transpose(1, 2)
354
+ value_states = value_states.view(shape_kv).transpose(1, 2)
355
+
356
+ if layer_past is not None:
357
+ past_key, past_value = layer_past
358
+ key_states = torch.cat((past_key, key_states), dim=-2)
359
+ value_states = torch.cat((past_value, value_states), dim=-2)
360
+
361
+ if use_cache is True:
362
+ present = (key_states, value_states)
363
+ else:
364
+ present = None
365
+
366
+ is_cross_attention = encoder_hidden_states is not None
367
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
368
+
369
+ using_eager = self.config._attn_implementation == "eager"
370
+ attention_interface: Callable = eager_attention_forward
371
+ if self.config._attn_implementation != "eager":
372
+ if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
373
+ using_eager = True
374
+ logger.warning_once(
375
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
376
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
377
+ )
378
+ else:
379
+ # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
380
+ # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
381
+ # not necessarily to eager (if mentioned options are provided).
382
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
383
+
384
+ if using_eager and self.reorder_and_upcast_attn:
385
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
386
+ query_states, key_states, value_states, attention_mask, head_mask
387
+ )
388
+ else:
389
+ attn_output, attn_weights = attention_interface(
390
+ self,
391
+ query_states,
392
+ key_states,
393
+ value_states,
394
+ attention_mask,
395
+ head_mask=head_mask,
396
+ dropout=self.attn_dropout.p if self.training else 0.0,
397
+ is_causal=is_causal,
398
+ **kwargs,
399
+ )
400
+
401
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
402
+ attn_output = self.c_proj(attn_output)
403
+ attn_output = self.resid_dropout(attn_output)
404
+
405
+ outputs = (attn_output, present)
406
+ if output_attentions:
407
+ outputs += (attn_weights,)
408
+
409
+ return outputs # a, present, (attentions)
410
+
411
+ class RotatingheadGPT2Attention(GPT2Attention):
412
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
413
+ super().__init__(config, is_cross_attention, layer_idx)
414
+ if config.rotatinghead == 'lr':
415
+ self.rope = HeadSpecificLRRoPE(config.num_attention_heads, self.head_dim)
416
+ elif config.rotatinghead == 'gp':
417
+ self.rope = HeadSpecificGPRoPE(config.num_attention_heads, self.head_dim)
418
+ self.rotatinghead = config.rotatinghead
419
+
420
+ def forward(
421
+ self,
422
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
423
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
424
+ attention_mask: Optional[torch.FloatTensor] = None,
425
+ head_mask: Optional[torch.FloatTensor] = None,
426
+ encoder_hidden_states: Optional[torch.Tensor] = None,
427
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
428
+ use_cache: Optional[bool] = False,
429
+ output_attentions: Optional[bool] = False,
430
+ **kwargs,
431
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
432
+ if encoder_hidden_states is not None:
433
+ if not hasattr(self, "q_attn"):
434
+ raise ValueError(
435
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
436
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
437
+ )
438
+
439
+ query_states = self.q_attn(hidden_states)
440
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
441
+ attention_mask = encoder_attention_mask
442
+ else:
443
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
444
+
445
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
446
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
447
+
448
+ query_states = query_states.view(shape_q).transpose(1, 2)
449
+ key_states = key_states.view(shape_kv).transpose(1, 2)
450
+ value_states = value_states.view(shape_kv).transpose(1, 2)
451
+
452
+ if layer_past is not None:
453
+ past_key, past_value = layer_past
454
+ key_states = torch.cat((past_key, key_states), dim=-2)
455
+ value_states = torch.cat((past_value, value_states), dim=-2)
456
+
457
+ if use_cache is True:
458
+ present = (key_states, value_states)
459
+ else:
460
+ present = None
461
+
462
+ is_cross_attention = encoder_hidden_states is not None
463
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
464
+
465
+ using_eager = self.config._attn_implementation == "eager"
466
+ attention_interface: Callable = eager_attention_forward
467
+ if self.config._attn_implementation != "eager":
468
+ if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
469
+ using_eager = True
470
+ logger.warning_once(
471
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
472
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
473
+ )
474
+ else:
475
+ # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
476
+ # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
477
+ # not necessarily to eager (if mentioned options are provided).
478
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
479
+
480
+ query_states, key_states = self.rope(query_states, key_states)
481
+
482
+ if using_eager and self.reorder_and_upcast_attn:
483
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
484
+ query_states, key_states, value_states, attention_mask, head_mask
485
+ )
486
+ else:
487
+ attn_output, attn_weights = attention_interface(
488
+ self,
489
+ query_states,
490
+ key_states,
491
+ value_states,
492
+ attention_mask,
493
+ head_mask=head_mask,
494
+ dropout=self.attn_dropout.p if self.training else 0.0,
495
+ is_causal=is_causal,
496
+ **kwargs,
497
+ )
498
+
499
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
500
+ attn_output = self.c_proj(attn_output)
501
+ attn_output = self.resid_dropout(attn_output)
502
+
503
+ outputs = (attn_output, present)
504
+ if output_attentions:
505
+ outputs += (attn_weights,)
506
+
507
+ return outputs # a, present, (attentions)
508
+
509
+
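The only functional change in `RotatingheadGPT2Attention` relative to the stock `GPT2Attention` above is the `self.rope(query_states, key_states)` call before the attention kernel. A quick toy-size shape check, assuming it runs somewhere this module's classes are importable (the sizes are illustrative, not the trained configuration):

```python
import torch
from transformers import GPT2Config

cfg = GPT2Config(n_embd=64, n_head=4, n_positions=32)  # toy sizes
cfg.rotatinghead = "lr"                                # or "gp" for the geometric-progression variant

attn = RotatingheadGPT2Attention(cfg)
out, present = attn(torch.randn(1, 8, 64), use_cache=True)
print(out.shape, present[0].shape)  # torch.Size([1, 8, 64]) torch.Size([1, 4, 8, 16])
```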
510
+ class GPT2MLP(nn.Module):
511
+ def __init__(self, intermediate_size, config):
512
+ super().__init__()
513
+ embed_dim = config.hidden_size
514
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
515
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
516
+ self.act = ACT2FN[config.activation_function]
517
+ self.dropout = nn.Dropout(config.resid_pdrop)
518
+
519
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
520
+ hidden_states = self.c_fc(hidden_states)
521
+ hidden_states = self.act(hidden_states)
522
+ hidden_states = self.c_proj(hidden_states)
523
+ hidden_states = self.dropout(hidden_states)
524
+ return hidden_states
525
+
526
+
527
+ class GPT2Block(nn.Module):
528
+ def __init__(self, config, layer_idx=None):
529
+ super().__init__()
530
+ hidden_size = config.hidden_size
531
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
532
+
533
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
534
+ if config.rotatinghead is not None:
535
+ self.attn = RotatingheadGPT2Attention(config, layer_idx=layer_idx)
536
+ else:
537
+ self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
538
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
539
+
540
+ if config.add_cross_attention:
541
+ self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
542
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
543
+
544
+ self.mlp = GPT2MLP(inner_dim, config)
545
+
546
+ def forward(
547
+ self,
548
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
549
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
550
+ attention_mask: Optional[torch.FloatTensor] = None,
551
+ head_mask: Optional[torch.FloatTensor] = None,
552
+ encoder_hidden_states: Optional[torch.Tensor] = None,
553
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
554
+ use_cache: Optional[bool] = False,
555
+ output_attentions: Optional[bool] = False,
556
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
557
+ residual = hidden_states
558
+ hidden_states = self.ln_1(hidden_states)
559
+ attn_outputs = self.attn(
560
+ hidden_states,
561
+ layer_past=layer_past,
562
+ attention_mask=attention_mask,
563
+ head_mask=head_mask,
564
+ use_cache=use_cache,
565
+ output_attentions=output_attentions,
566
+ )
567
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
568
+ outputs = attn_outputs[1:]
569
+ # residual connection
570
+ hidden_states = attn_output + residual
571
+
572
+ if encoder_hidden_states is not None:
573
+ # add one self-attention block for cross-attention
574
+ if not hasattr(self, "crossattention"):
575
+ raise ValueError(
576
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
577
+ "cross-attention layers by setting `config.add_cross_attention=True`"
578
+ )
579
+ residual = hidden_states
580
+ hidden_states = self.ln_cross_attn(hidden_states)
581
+ cross_attn_outputs = self.crossattention(
582
+ hidden_states,
583
+ attention_mask=attention_mask,
584
+ head_mask=head_mask,
585
+ encoder_hidden_states=encoder_hidden_states,
586
+ encoder_attention_mask=encoder_attention_mask,
587
+ output_attentions=output_attentions,
588
+ )
589
+ attn_output = cross_attn_outputs[0]
590
+ # residual connection
591
+ hidden_states = residual + attn_output
592
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
593
+
594
+ residual = hidden_states
595
+ hidden_states = self.ln_2(hidden_states)
596
+ feed_forward_hidden_states = self.mlp(hidden_states)
597
+ # residual connection
598
+ hidden_states = residual + feed_forward_hidden_states
599
+
600
+ if use_cache:
601
+ outputs = (hidden_states,) + outputs
602
+ else:
603
+ outputs = (hidden_states,) + outputs[1:]
604
+
605
+ return outputs # hidden_states, present, (attentions, cross_attentions)
606
+
607
+
608
+ class RotatingHeadGPT2PretrainedModel(PreTrainedModel):
609
+ """
610
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
611
+ models.
612
+ """
613
+
614
+ config_class = RotatingHeadGPT2Config
615
+ load_tf_weights = load_tf_weights_in_gpt2
616
+ base_model_prefix = "transformer"
617
+ is_parallelizable = True
618
+ supports_gradient_checkpointing = True
619
+ _no_split_modules = ["GPT2Block"]
620
+ _skip_keys_device_placement = "past_key_values"
621
+ _supports_flash_attn_2 = True
622
+ _supports_sdpa = True
623
+
624
+ def __init__(self, *inputs, **kwargs):
625
+ super().__init__(*inputs, **kwargs)
626
+
627
+ def _init_weights(self, module):
628
+ """Initialize the weights."""
629
+ if isinstance(module, (nn.Linear, Conv1D)):
630
+ # Slightly different from the TF version which uses truncated_normal for initialization
631
+ # cf https://github.com/pytorch/pytorch/pull/5617
632
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
633
+ if module.bias is not None:
634
+ module.bias.data.zero_()
635
+ elif isinstance(module, nn.Embedding):
636
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
637
+ if module.padding_idx is not None:
638
+ module.weight.data[module.padding_idx].zero_()
639
+ elif isinstance(module, nn.LayerNorm):
640
+ module.bias.data.zero_()
641
+ module.weight.data.fill_(1.0)
642
+
643
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
644
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
645
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
646
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
647
+ #
648
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
649
+ for name, p in module.named_parameters():
650
+ if name == "c_proj.weight":
651
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
652
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
653
+
654
+
655
+ @dataclass
656
+ class GPT2DoubleHeadsModelOutput(ModelOutput):
657
+ """
658
+ Base class for outputs of models predicting if two sentences are consecutive or not.
659
+
660
+ Args:
661
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
662
+ Language modeling loss.
663
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
664
+ Multiple choice classification loss.
665
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
666
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
667
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
668
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
669
+ past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
670
+ Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
671
+ sequence_length, embed_size_per_head)`).
672
+
673
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
674
+ `past_key_values` input) to speed up sequential decoding.
675
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
676
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
677
+ shape `(batch_size, sequence_length, hidden_size)`.
678
+
679
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
680
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
681
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
682
+ sequence_length)`.
683
+
684
+ GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
685
+ self-attention heads.
686
+ """
687
+
688
+ loss: Optional[torch.FloatTensor] = None
689
+ mc_loss: Optional[torch.FloatTensor] = None
690
+ logits: torch.FloatTensor = None
691
+ mc_logits: torch.FloatTensor = None
692
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
693
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
694
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
695
+
696
+
697
+
698
+ class RotatingHeadGPT2Model(RotatingHeadGPT2PretrainedModel):
699
+ _supports_param_buffer_assignment = False
700
+
701
+ def __init__(self, config):
702
+ super().__init__(config)
703
+
704
+ self.embed_dim = config.hidden_size
705
+
706
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
707
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
708
+
709
+ self.drop = nn.Dropout(config.embd_pdrop)
710
+ self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
711
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
712
+
713
+ # Model parallel
714
+ self.model_parallel = False
715
+ self.device_map = None
716
+ self.gradient_checkpointing = False
717
+ self._attn_implementation = config._attn_implementation
718
+
719
+ # Initialize weights and apply final processing
720
+ self.post_init()
721
+
722
+ def parallelize(self, device_map=None):
723
+ # Check validity of device_map
724
+ warnings.warn(
725
+ "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
726
+ " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
727
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
728
+ " ...}",
729
+ FutureWarning,
730
+ )
731
+ self.device_map = (
732
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
733
+ )
734
+ assert_device_map(self.device_map, len(self.h))
735
+ self.model_parallel = True
736
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
737
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
738
+ self.wte = self.wte.to(self.first_device)
739
+ self.wpe = self.wpe.to(self.first_device)
740
+ # Load onto devices
741
+ for k, v in self.device_map.items():
742
+ for block in v:
743
+ cuda_device = "cuda:" + str(k)
744
+ self.h[block] = self.h[block].to(cuda_device)
745
+ # ln_f to last
746
+ self.ln_f = self.ln_f.to(self.last_device)
747
+
748
+ def deparallelize(self):
749
+ warnings.warn(
750
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
751
+ FutureWarning,
752
+ )
753
+ self.model_parallel = False
754
+ self.device_map = None
755
+ self.first_device = "cpu"
756
+ self.last_device = "cpu"
757
+ self.wte = self.wte.to("cpu")
758
+ self.wpe = self.wpe.to("cpu")
759
+ for index in range(len(self.h)):
760
+ self.h[index] = self.h[index].to("cpu")
761
+ self.ln_f = self.ln_f.to("cpu")
762
+ torch.cuda.empty_cache()
763
+
764
+ def get_input_embeddings(self):
765
+ return self.wte
766
+
767
+ def set_input_embeddings(self, new_embeddings):
768
+ self.wte = new_embeddings
769
+
770
+ def _prune_heads(self, heads_to_prune):
771
+ """
772
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
773
+ """
774
+ for layer, heads in heads_to_prune.items():
775
+ self.h[layer].attn.prune_heads(heads)
776
+
777
+ def forward(
778
+ self,
779
+ input_ids: Optional[torch.LongTensor] = None,
780
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
781
+ attention_mask: Optional[torch.FloatTensor] = None,
782
+ token_type_ids: Optional[torch.LongTensor] = None,
783
+ position_ids: Optional[torch.LongTensor] = None,
784
+ head_mask: Optional[torch.FloatTensor] = None,
785
+ inputs_embeds: Optional[torch.FloatTensor] = None,
786
+ encoder_hidden_states: Optional[torch.Tensor] = None,
787
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
788
+ use_cache: Optional[bool] = None,
789
+ output_attentions: Optional[bool] = None,
790
+ output_hidden_states: Optional[bool] = None,
791
+ return_dict: Optional[bool] = None,
792
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
793
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
794
+ output_hidden_states = (
795
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
796
+ )
797
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
798
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
799
+
800
+ if input_ids is not None and inputs_embeds is not None:
801
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
802
+ elif input_ids is not None:
803
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
804
+ input_shape = input_ids.size()
805
+ input_ids = input_ids.view(-1, input_shape[-1])
806
+ batch_size = input_ids.shape[0]
807
+ elif inputs_embeds is not None:
808
+ input_shape = inputs_embeds.size()[:-1]
809
+ batch_size = inputs_embeds.shape[0]
810
+ else:
811
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
812
+
813
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
814
+
815
+ if token_type_ids is not None:
816
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
817
+
818
+ if past_key_values is None:
819
+ past_length = 0
820
+ past_key_values = tuple([None] * len(self.h))
821
+ else:
822
+ past_length = past_key_values[0][0].size(-2)
823
+ if position_ids is None:
824
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
825
+ position_ids = position_ids.unsqueeze(0)
826
+
827
+ if inputs_embeds is None:
828
+ inputs_embeds = self.wte(input_ids)
829
+ position_embeds = self.wpe(position_ids)
830
+ hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
831
+
832
+ # Attention mask.
833
+ _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
834
+ attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
835
+ if self._attn_implementation == "flash_attention_2":
836
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
837
+ elif _use_sdpa:
838
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
839
+ attention_mask=attention_mask,
840
+ input_shape=(batch_size, input_shape[-1]),
841
+ inputs_embeds=inputs_embeds,
842
+ past_key_values_length=past_length,
843
+ )
844
+ else:
845
+ if attention_mask is not None:
846
+ # We create a 3D attention mask from a 2D tensor mask.
847
+ # Sizes are [batch_size, 1, 1, to_seq_length]
848
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
849
+ # this attention mask is more simple than the triangular masking of causal attention
850
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
851
+ attention_mask = attention_mask[:, None, None, :]
852
+
853
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
854
+ # masked positions, this operation will create a tensor which is 0.0 for
855
+ # positions we want to attend and the dtype's smallest value for masked positions.
856
+ # Since we are adding it to the raw scores before the softmax, this is
857
+ # effectively the same as removing these entirely.
858
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
859
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
860
+
861
+ # If a 2D or 3D attention mask is provided for the cross-attention
862
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
863
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
864
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
865
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
866
+ if encoder_attention_mask is None:
867
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
868
+ if _use_sdpa:
869
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
870
+ mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
871
+ )
872
+ elif not self._attn_implementation == "flash_attention_2":
873
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
874
+ else:
875
+ encoder_attention_mask = None
876
+
877
+ # Prepare head mask if needed
878
+ # 1.0 in head_mask indicate we keep the head
879
+ # attention_probs has shape bsz x n_heads x N x N
880
+ # head_mask has shape n_layer x batch x n_heads x N x N
881
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
882
+
883
+ if token_type_ids is not None:
884
+ token_type_embeds = self.wte(token_type_ids)
885
+ hidden_states = hidden_states + token_type_embeds
886
+
887
+ hidden_states = self.drop(hidden_states)
888
+
889
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
890
+
891
+ if self.gradient_checkpointing and self.training:
892
+ if use_cache:
893
+ logger.warning_once(
894
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
895
+ )
896
+ use_cache = False
897
+
898
+ presents = () if use_cache else None
899
+ all_self_attentions = () if output_attentions else None
900
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
901
+ all_hidden_states = () if output_hidden_states else None
902
+ for i in range(len(self.h)):
903
+ block, layer_past = self.h[i], past_key_values[i]
904
+ # Model parallel
905
+ if self.model_parallel:
906
+ torch.cuda.set_device(hidden_states.device)
907
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
908
+ if layer_past is not None:
909
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
910
+ # Ensure that attention_mask is always on the same device as hidden_states
911
+ if attention_mask is not None:
912
+ attention_mask = attention_mask.to(hidden_states.device)
913
+ if isinstance(head_mask, torch.Tensor):
914
+ head_mask = head_mask.to(hidden_states.device)
915
+ if output_hidden_states:
916
+ all_hidden_states = all_hidden_states + (hidden_states,)
917
+
918
+ if self.gradient_checkpointing and self.training:
919
+ outputs = self._gradient_checkpointing_func(
920
+ block.__call__,
921
+ hidden_states,
922
+ None,
923
+ attention_mask,
924
+ head_mask[i],
925
+ encoder_hidden_states,
926
+ encoder_attention_mask,
927
+ use_cache,
928
+ output_attentions,
929
+ )
930
+ else:
931
+ outputs = block(
932
+ hidden_states,
933
+ layer_past=layer_past,
934
+ attention_mask=attention_mask,
935
+ head_mask=head_mask[i],
936
+ encoder_hidden_states=encoder_hidden_states,
937
+ encoder_attention_mask=encoder_attention_mask,
938
+ use_cache=use_cache,
939
+ output_attentions=output_attentions,
940
+ )
941
+
942
+ hidden_states = outputs[0]
943
+ if use_cache is True:
944
+ presents = presents + (outputs[1],)
945
+
946
+ if output_attentions:
947
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
948
+ if self.config.add_cross_attention:
949
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
950
+
951
+ # Model Parallel: If it's the last layer for that device, put things on the next device
952
+ if self.model_parallel:
953
+ for k, v in self.device_map.items():
954
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
955
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
956
+
957
+ hidden_states = self.ln_f(hidden_states)
958
+
959
+ hidden_states = hidden_states.view(output_shape)
960
+ # Add last hidden state
961
+ if output_hidden_states:
962
+ all_hidden_states = all_hidden_states + (hidden_states,)
963
+
964
+ if not return_dict:
965
+ return tuple(
966
+ v
967
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
968
+ if v is not None
969
+ )
970
+
971
+ return BaseModelOutputWithPastAndCrossAttentions(
972
+ last_hidden_state=hidden_states,
973
+ past_key_values=presents,
974
+ hidden_states=all_hidden_states,
975
+ attentions=all_self_attentions,
976
+ cross_attentions=all_cross_attentions,
977
+ )
978
+
979
+
980
+ class RotatingHeadGPT2LMHeadModel(RotatingHeadGPT2PretrainedModel, GenerationMixin):
981
+ _tied_weights_keys = ["lm_head.weight"]
982
+
983
+ def __init__(self, config):
984
+ super().__init__(config)
985
+ self.transformer = RotatingHeadGPT2Model(config)
986
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
987
+
988
+ # Model parallel
989
+ self.model_parallel = False
990
+ self.device_map = None
991
+
992
+ # Initialize weights and apply final processing
993
+ self.post_init()
994
+
995
+ def parallelize(self, device_map=None):
996
+ warnings.warn(
997
+ "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
998
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
999
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
1000
+ " 0, 'transformer.h.1': 1, ...}",
1001
+ FutureWarning,
1002
+ )
1003
+ self.device_map = (
1004
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1005
+ if device_map is None
1006
+ else device_map
1007
+ )
1008
+ assert_device_map(self.device_map, len(self.transformer.h))
1009
+ self.transformer.parallelize(self.device_map)
1010
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1011
+ self.model_parallel = True
1012
+
1013
+ def deparallelize(self):
1014
+ warnings.warn(
1015
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1016
+ FutureWarning,
1017
+ )
1018
+ self.transformer.deparallelize()
1019
+ self.transformer = self.transformer.to("cpu")
1020
+ self.lm_head = self.lm_head.to("cpu")
1021
+ self.model_parallel = False
1022
+ torch.cuda.empty_cache()
1023
+
1024
+ def get_output_embeddings(self):
1025
+ return self.lm_head
1026
+
1027
+ def set_output_embeddings(self, new_embeddings):
1028
+ self.lm_head = new_embeddings
1029
+
1030
+ def forward(
1031
+ self,
1032
+ input_ids: Optional[torch.LongTensor] = None,
1033
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1034
+ attention_mask: Optional[torch.FloatTensor] = None,
1035
+ token_type_ids: Optional[torch.LongTensor] = None,
1036
+ position_ids: Optional[torch.LongTensor] = None,
1037
+ head_mask: Optional[torch.FloatTensor] = None,
1038
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1039
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1040
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1041
+ labels: Optional[torch.LongTensor] = None,
1042
+ use_cache: Optional[bool] = None,
1043
+ output_attentions: Optional[bool] = None,
1044
+ output_hidden_states: Optional[bool] = None,
1045
+ return_dict: Optional[bool] = None,
1046
+ **kwargs,
1047
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1048
+ r"""
1049
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1050
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1051
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1052
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1053
+ """
1054
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1055
+
1056
+ transformer_outputs = self.transformer(
1057
+ input_ids,
1058
+ past_key_values=past_key_values,
1059
+ attention_mask=attention_mask,
1060
+ token_type_ids=token_type_ids,
1061
+ position_ids=position_ids,
1062
+ head_mask=head_mask,
1063
+ inputs_embeds=inputs_embeds,
1064
+ encoder_hidden_states=encoder_hidden_states,
1065
+ encoder_attention_mask=encoder_attention_mask,
1066
+ use_cache=use_cache,
1067
+ output_attentions=output_attentions,
1068
+ output_hidden_states=output_hidden_states,
1069
+ return_dict=return_dict,
1070
+ )
1071
+ hidden_states = transformer_outputs[0]
1072
+
1073
+ # Set device for model parallelism
1074
+ if self.model_parallel:
1075
+ torch.cuda.set_device(self.transformer.first_device)
1076
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1077
+
1078
+ lm_logits = self.lm_head(hidden_states)
1079
+
1080
+ loss = None
1081
+ if labels is not None:
1082
+ # Flatten the tokens
1083
+ loss = self.loss_function(
1084
+ lm_logits,
1085
+ labels,
1086
+ vocab_size=self.config.vocab_size,
1087
+ **kwargs,
1088
+ )
1089
+
1090
+ if not return_dict:
1091
+ output = (lm_logits,) + transformer_outputs[1:]
1092
+ return ((loss,) + output) if loss is not None else output
1093
+
1094
+ return CausalLMOutputWithCrossAttentions(
1095
+ loss=loss,
1096
+ logits=lm_logits,
1097
+ past_key_values=transformer_outputs.past_key_values,
1098
+ hidden_states=transformer_outputs.hidden_states,
1099
+ attentions=transformer_outputs.attentions,
1100
+ cross_attentions=transformer_outputs.cross_attentions,
1101
+ )
1102
+
1103
+ @staticmethod
1104
+ def _reorder_cache(
1105
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1106
+ ) -> Tuple[Tuple[torch.Tensor]]:
1107
+ """
1108
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1109
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1110
+ beam_idx at every generation step.
1111
+ """
1112
+ return tuple(
1113
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1114
+ for layer_past in past_key_values
1115
+ )
1116
+
1117
+ __all__ = [
1118
+ "RotatingHeadGPT2LMHeadModel",
1119
+ "RotatingHeadGPT2Model",
1120
+ "RotatingHeadGPT2PretrainedModel",
1121
+ "load_tf_weights_in_gpt2",
1122
+ ]
1123
+
1124
+
1125
+ if __name__ == "__main__":
1126
+ cg = GPT2Config.from_pretrained("gpt2-medium")
1127
+ cg.rotatinghead = 'gp'
1128
+ model = RotatingHeadGPT2LMHeadModel(cg)
1129
+ from src.utils.model_utlis import print_trainable_parameters
1130
+ print_trainable_parameters(model)
1131
+ model(torch.randint(0, 10000, (1, 100)))
1132
+ print()
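For local scripts that prefer the Auto* API over `trust_remote_code`, the custom classes can be registered explicitly. A minimal sketch, assuming the file is importable as `modeling_rotating_head_gpt2` and that its `src.models.modeling_gpt2` dependency is on the Python path; the class and registration names mirror this file, nothing else is assumed:

```python
from transformers import AutoConfig, AutoModelForCausalLM

from modeling_rotating_head_gpt2 import (
    RotatingHeadGPT2Config,
    RotatingHeadGPT2LMHeadModel,
)

# Map the custom model_type to the custom config / model classes
AutoConfig.register("rotating-head-gpt2", RotatingHeadGPT2Config)
AutoModelForCausalLM.register(RotatingHeadGPT2Config, RotatingHeadGPT2LMHeadModel)

config = RotatingHeadGPT2Config.from_pretrained("gpt2-medium")
config.rotatinghead = "gp"  # 'gp' (geometric progression) or 'lr' (learnable frequencies)
model = AutoModelForCausalLM.from_config(config)
```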