shivanandmn committed on
Commit
56afc7d
verified
1 Parent(s): 0d3e238

Model save

Files changed (3)
  1. README.md +79 -0
  2. generation_config.json +7 -0
  3. modeling_dd_gpt2.py +1109 -0
README.md ADDED
@@ -0,0 +1,79 @@
+ ---
+ library_name: transformers
+ tags:
+ - generated_from_trainer
+ metrics:
+ - accuracy
+ - bleu
+ model-index:
+ - name: dd-gpt2-medium-wikitext
+ results: []
+ ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # dd-gpt2-medium-wikitext
+
+ This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
+ It achieves the following results on the evaluation set:
+ - Loss: 3.3729
+ - Accuracy: 0.4006
+ - Perplexity: 29.1627
+ - Bleu: 0.1356
+
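+ For reference, the reported perplexity is the exponential of the evaluation loss: exp(3.3729) ≈ 29.16.
+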
+ ## Model description
+
+ More information needed
+
+ ## Intended uses & limitations
+
+ More information needed
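+
+ As a minimal loading sketch (assumptions: the checkpoint is published as `shivanandmn/dd-gpt2-medium-wikitext`, the custom `DDGPT2LMHeadModel` from `modeling_dd_gpt2.py` is exposed via `trust_remote_code`, and the standard GPT-2 BPE tokenizer applies, matching the bos/eos id 50256 in `generation_config.json`):
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ repo_id = "shivanandmn/dd-gpt2-medium-wikitext"  # assumed repository id
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed: standard GPT-2 BPE tokenizer
+ model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
+ model.eval()
+
+ inputs = tokenizer("The history of the steam engine", return_tensors="pt")
+ with torch.no_grad():
+     # the shipped generation_config.json sets use_cache to false, so decoding recomputes the full context each step
+     output = model.generate(**inputs, max_new_tokens=40)
+ print(tokenizer.decode(output[0], skip_special_tokens=True))
+ ```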
+
+ ## Training and evaluation data
+
+ More information needed
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 0.0001
+ - train_batch_size: 64
+ - eval_batch_size: 64
+ - seed: 42
+ - optimizer: AdamW (torch implementation) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
+ - lr_scheduler_type: linear
+ - lr_scheduler_warmup_ratio: 0.1
+ - num_epochs: 5
+
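+ The list above corresponds roughly to the following `TrainingArguments`; this is a reconstruction for illustration, not the original training script (in particular, the per-device vs. total batch size split is assumed):
+
+ ```python
+ from transformers import TrainingArguments
+
+ # Approximate reconstruction of the hyperparameters listed above.
+ training_args = TrainingArguments(
+     output_dir="dd-gpt2-medium-wikitext",
+     learning_rate=1e-4,
+     per_device_train_batch_size=64,
+     per_device_eval_batch_size=64,
+     seed=42,
+     optim="adamw_torch",
+     lr_scheduler_type="linear",
+     warmup_ratio=0.1,
+     num_train_epochs=5,
+ )
+ ```
+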
+ ### Training results
+
+ | Training Loss | Epoch | Step | Validation Loss | Accuracy | Perplexity | Bleu |
+ |:-------------:|:------:|:----:|:---------------:|:--------:|:----------:|:------:|
+ | 6.3499 | 0.2806 | 500 | 6.2328 | 0.1688 | 509.1785 | 0.0261 |
+ | 5.4979 | 0.5612 | 1000 | 5.3734 | 0.2228 | 215.6041 | 0.0506 |
+ | 4.8996 | 0.8418 | 1500 | 4.7975 | 0.2650 | 121.2067 | 0.0669 |
+ | 4.5102 | 1.1223 | 2000 | 4.4042 | 0.2992 | 81.7968 | 0.0791 |
+ | 4.2029 | 1.4029 | 2500 | 4.1110 | 0.3301 | 61.0070 | 0.0887 |
+ | 4.0332 | 1.6835 | 3000 | 3.9383 | 0.3457 | 51.3319 | 0.0996 |
+ | 3.8911 | 1.9641 | 3500 | 3.8146 | 0.3575 | 45.3566 | 0.1107 |
+ | 3.7698 | 2.2447 | 4000 | 3.7189 | 0.3663 | 41.2194 | 0.1154 |
+ | 3.6812 | 2.5253 | 4500 | 3.6449 | 0.3729 | 38.2808 | 0.1225 |
+ | 3.63 | 2.8058 | 5000 | 3.5815 | 0.3790 | 35.9274 | 0.1216 |
+ | 3.5287 | 3.0864 | 5500 | 3.5309 | 0.3840 | 34.1532 | 0.1261 |
+ | 3.5032 | 3.3670 | 6000 | 3.4913 | 0.3883 | 32.8286 | 0.1302 |
+ | 3.4684 | 3.6476 | 6500 | 3.4542 | 0.3917 | 31.6327 | 0.1304 |
+ | 3.4365 | 3.9282 | 7000 | 3.4250 | 0.3949 | 30.7240 | 0.1303 |
+ | 3.3894 | 4.2088 | 7500 | 3.4020 | 0.3973 | 30.0227 | 0.1327 |
+ | 3.3446 | 4.4893 | 8000 | 3.3850 | 0.3992 | 29.5189 | 0.1336 |
+ | 3.3532 | 4.7699 | 8500 | 3.3729 | 0.4006 | 29.1627 | 0.1356 |
+
+
+ ### Framework versions
+
+ - Transformers 4.49.0
+ - Pytorch 2.6.0+cu124
+ - Datasets 3.3.2
+ - Tokenizers 0.21.0
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 50256,
+ "eos_token_id": 50256,
+ "transformers_version": "4.49.0",
+ "use_cache": false
+ }
modeling_dd_gpt2.py ADDED
@@ -0,0 +1,1109 @@
1
+
2
+ """PyTorch OpenAI GPT-2 model, adapted from the Hugging Face implementation, with drift-and-diffusion (DD) attention modifications."""
3
+
4
+ import math
5
+ import os
6
+ import warnings
7
+ from dataclasses import dataclass
8
+ from typing import Callable, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.generation import GenerationMixin
17
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
18
+ from transformers.modeling_outputs import (
19
+ BaseModelOutputWithPastAndCrossAttentions,
20
+ CausalLMOutputWithCrossAttentions,
21
+ QuestionAnsweringModelOutput,
22
+ SequenceClassifierOutputWithPast,
23
+ TokenClassifierOutput,
24
+ )
25
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary
26
+ from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
27
+ from transformers.utils import (
28
+ ModelOutput,
29
+ add_code_sample_docstrings,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
36
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
37
+ # GPT2Block and DDGPT2PretrainedModel are defined later in this module; no import from a local `src` package is needed to load this file standalone.
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ class DDGPT2Config(GPT2Config):
42
+ model_type = "dd-gpt2"
43
+ architectures = ["DDGPT2LMHeadModel"]
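+ # Beyond the standard GPT-2 fields, the saved config is expected to provide three booleans
+ # consumed below: `apply_drift`, `apply_diffusion`, and `baseline_each_head`.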
44
+
45
+ # DDGPT2PretrainedModel (with config_class = DDGPT2Config) is defined below, after the attention and block classes.
47
+
48
+
49
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
50
+ """Load tf checkpoints in a pytorch model"""
51
+ try:
52
+ import re
53
+
54
+ import tensorflow as tf
55
+ except ImportError:
56
+ logger.error(
57
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
58
+ "https://www.tensorflow.org/install/ for installation instructions."
59
+ )
60
+ raise
61
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
62
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
63
+ # Load weights from TF model
64
+ init_vars = tf.train.list_variables(tf_path)
65
+ names = []
66
+ arrays = []
67
+ for name, shape in init_vars:
68
+ logger.info(f"Loading TF weight {name} with shape {shape}")
69
+ array = tf.train.load_variable(tf_path, name)
70
+ names.append(name)
71
+ arrays.append(array.squeeze())
72
+
73
+ for name, array in zip(names, arrays):
74
+ name = name[6:] # skip "model/"
75
+ name = name.split("/")
76
+ pointer = model
77
+ for m_name in name:
78
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
79
+ scope_names = re.split(r"(\d+)", m_name)
80
+ else:
81
+ scope_names = [m_name]
82
+ if scope_names[0] == "w" or scope_names[0] == "g":
83
+ pointer = getattr(pointer, "weight")
84
+ elif scope_names[0] == "b":
85
+ pointer = getattr(pointer, "bias")
86
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
87
+ pointer = getattr(pointer, scope_names[0])
88
+ pointer = getattr(pointer, "weight")
89
+ else:
90
+ pointer = getattr(pointer, scope_names[0])
91
+ if len(scope_names) >= 2:
92
+ num = int(scope_names[1])
93
+ pointer = pointer[num]
94
+ try:
95
+ if pointer.shape != array.shape:
96
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
97
+ except ValueError as e:
98
+ e.args += (pointer.shape, array.shape)
99
+ raise
100
+ logger.info(f"Initialize PyTorch weight {name}")
101
+ pointer.data = torch.from_numpy(array)
102
+ return model
103
+
104
+
105
+ def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
106
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
107
+
108
+ if module.scale_attn_weights:
109
+ attn_weights = attn_weights / torch.full(
110
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
111
+ )
112
+
113
+ # Layer-wise attention scaling
114
+ if module.scale_attn_by_inverse_layer_idx:
115
+ attn_weights = attn_weights / float(module.layer_idx + 1)
116
+
117
+ if not module.is_cross_attention:
118
+ # if only "normal" attention layer implements causal mask
119
+ query_length, key_length = query.size(-2), key.size(-2)
120
+ causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
121
+ mask_value = torch.finfo(attn_weights.dtype).min
122
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
123
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
124
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
125
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
126
+
127
+ if attention_mask is not None:
128
+ # Apply the attention mask
129
+ attn_weights = attn_weights + attention_mask
130
+
131
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
132
+
133
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
134
+ attn_weights = attn_weights.type(value.dtype)
135
+ attn_weights = module.attn_dropout(attn_weights)
136
+
137
+ # Mask heads if we want to
138
+ if head_mask is not None:
139
+ attn_weights = attn_weights * head_mask
140
+
141
+ attn_output = torch.matmul(attn_weights, value)
142
+ attn_output = attn_output.transpose(1, 2)
143
+
144
+ return attn_output, attn_weights
145
+
146
+
147
+ class GPT2Attention(nn.Module):
148
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
149
+ super().__init__()
150
+ self.config = config
151
+ max_positions = config.max_position_embeddings
152
+ self.register_buffer(
153
+ "bias",
154
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
155
+ 1, 1, max_positions, max_positions
156
+ ),
157
+ persistent=False,
158
+ )
159
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
160
+
161
+ self.embed_dim = config.hidden_size
162
+ self.num_heads = config.num_attention_heads
163
+ self.head_dim = self.embed_dim // self.num_heads
164
+ self.split_size = self.embed_dim
165
+ if self.head_dim * self.num_heads != self.embed_dim:
166
+ raise ValueError(
167
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
168
+ f" {self.num_heads})."
169
+ )
170
+
171
+ self.scale_attn_weights = config.scale_attn_weights
172
+ self.is_cross_attention = is_cross_attention
173
+
174
+ # Layer-wise attention scaling, reordering, and upcasting
175
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
176
+ self.layer_idx = layer_idx
177
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
178
+
179
+ if self.is_cross_attention:
180
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
181
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
182
+ else:
183
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
184
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
185
+
186
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
187
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
188
+ self.is_causal = True
189
+
190
+ self.pruned_heads = set()
191
+
192
+ def prune_heads(self, heads):
193
+ if len(heads) == 0:
194
+ return
195
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
196
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
197
+
198
+ # Prune conv1d layers
199
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
200
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
201
+
202
+ # Update hyper params
203
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
204
+ self.num_heads = self.num_heads - len(heads)
205
+ self.pruned_heads = self.pruned_heads.union(heads)
206
+
207
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
208
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
209
+ bsz, num_heads, q_seq_len, dk = query.size()
210
+ _, _, k_seq_len, _ = key.size()
211
+
212
+ # Preallocate attn_weights for `baddbmm`
213
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
214
+
215
+ # Compute Scale Factor
216
+ scale_factor = 1.0
217
+ if self.scale_attn_weights:
218
+ scale_factor /= float(value.size(-1)) ** 0.5
219
+
220
+ if self.scale_attn_by_inverse_layer_idx:
221
+ scale_factor /= float(self.layer_idx + 1)
222
+
223
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
224
+ with torch.amp.autocast(query.device.type, enabled=False):
225
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
226
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
227
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
228
+
229
+ if not self.is_cross_attention:
230
+ # if only "normal" attention layer implements causal mask
231
+ query_length, key_length = query.size(-2), key.size(-2)
232
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
233
+ mask_value = torch.finfo(attn_weights.dtype).min
234
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
235
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
236
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
237
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
238
+
239
+ if attention_mask is not None:
240
+ # Apply the attention mask
241
+ attn_weights = attn_weights + attention_mask
242
+
243
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
244
+
245
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
246
+ if attn_weights.dtype != torch.float32:
247
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
248
+ attn_weights = attn_weights.type(value.dtype)
249
+ attn_weights = self.attn_dropout(attn_weights)
250
+
251
+ # Mask heads if we want to
252
+ if head_mask is not None:
253
+ attn_weights = attn_weights * head_mask
254
+
255
+ attn_output = torch.matmul(attn_weights, value)
256
+ attn_output = attn_output.transpose(1, 2)
257
+
258
+ return attn_output, attn_weights
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
263
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
264
+ attention_mask: Optional[torch.FloatTensor] = None,
265
+ head_mask: Optional[torch.FloatTensor] = None,
266
+ encoder_hidden_states: Optional[torch.Tensor] = None,
267
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
268
+ use_cache: Optional[bool] = False,
269
+ output_attentions: Optional[bool] = False,
270
+ **kwargs,
271
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
272
+ if encoder_hidden_states is not None:
273
+ if not hasattr(self, "q_attn"):
274
+ raise ValueError(
275
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
276
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
277
+ )
278
+
279
+ query_states = self.q_attn(hidden_states)
280
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
281
+ attention_mask = encoder_attention_mask
282
+ else:
283
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
284
+
285
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
286
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
287
+
288
+ query_states = query_states.view(shape_q).transpose(1, 2)
289
+ key_states = key_states.view(shape_kv).transpose(1, 2)
290
+ value_states = value_states.view(shape_kv).transpose(1, 2)
291
+
292
+ if layer_past is not None:
293
+ past_key, past_value = layer_past
294
+ key_states = torch.cat((past_key, key_states), dim=-2)
295
+ value_states = torch.cat((past_value, value_states), dim=-2)
296
+
297
+ if use_cache is True:
298
+ present = (key_states, value_states)
299
+ else:
300
+ present = None
301
+
302
+ is_cross_attention = encoder_hidden_states is not None
303
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
304
+
305
+ using_eager = self.config._attn_implementation == "eager"
306
+ attention_interface: Callable = eager_attention_forward
307
+ if self.config._attn_implementation != "eager":
308
+ if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
309
+ using_eager = True
310
+ logger.warning_once(
311
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
312
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
313
+ )
314
+ else:
315
+ # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
316
+ # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
317
+ # not necessarily to eager (if mentioned options are provided).
318
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
319
+
320
+ if using_eager and self.reorder_and_upcast_attn:
321
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
322
+ query_states, key_states, value_states, attention_mask, head_mask
323
+ )
324
+ else:
325
+ attn_output, attn_weights = attention_interface(
326
+ self,
327
+ query_states,
328
+ key_states,
329
+ value_states,
330
+ attention_mask,
331
+ head_mask=head_mask,
332
+ dropout=self.attn_dropout.p if self.training else 0.0,
333
+ is_causal=is_causal,
334
+ **kwargs,
335
+ )
336
+
337
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
338
+ attn_output = self.c_proj(attn_output)
339
+ attn_output = self.resid_dropout(attn_output)
340
+
341
+ outputs = (attn_output, present)
342
+ if output_attentions:
343
+ outputs += (attn_weights,)
344
+
345
+ return outputs # a, present, (attentions)
346
+
347
+ class GPT2AttentionWithDD(GPT2Attention):
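+ # GPT2Attention with two optional learned "DD" terms applied to the per-head attention output
+ # before the c_proj output projection:
+ #   * drift:     y = x + sigmoid(d) * (x - b) when `baseline_each_head` is set, else y = x + sigmoid(d) * x
+ #   * diffusion: y = x + s * eps, with eps ~ N(0, I) during training and eps = 0 at inference
+ # d, s and the baseline b are learned per-head, per-channel parameters of shape (num_heads, head_dim).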
348
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
349
+ super().__init__(config, is_cross_attention, layer_idx)
350
+ if config.apply_drift:
351
+ self.drift_param = nn.Parameter(torch.randn(
352
+ config.num_attention_heads, self.head_dim))
353
+ if config.baseline_each_head:
354
+ self.baseline = nn.Parameter(torch.randn(config.num_attention_heads, self.head_dim, dtype=torch.float32) * 1e-3)
355
+ if config.apply_diffusion:
356
+ self.diffusion_param = nn.Parameter(torch.randn(
357
+ config.num_attention_heads, self.head_dim))
358
+
359
+ def apply_diffusion(self, attn_output):
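+ # Adds noise scaled elementwise by the learned `diffusion_param`; the noise is zero outside
+ # training, so this is a no-op at inference time.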
360
+ diffusion_component = self.diffusion_param.unsqueeze(0).unsqueeze(1)
361
+ # Create a noise term with the same shape as attn_output
362
+ # If deterministic, set noise to zero; else, use random noise
363
+ noise = torch.zeros_like(attn_output) if not self.training else torch.randn_like(attn_output)
364
+ attn_output = attn_output + diffusion_component * noise
365
+ return attn_output
366
+
367
+ def gpt_forward_without_cproj(
368
+ self,
369
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
370
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
371
+ attention_mask: Optional[torch.FloatTensor] = None,
372
+ head_mask: Optional[torch.FloatTensor] = None,
373
+ encoder_hidden_states: Optional[torch.Tensor] = None,
374
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
375
+ use_cache: Optional[bool] = False,
376
+ output_attentions: Optional[bool] = False,
377
+ **kwargs,
378
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
379
+ if encoder_hidden_states is not None:
380
+ if not hasattr(self, "q_attn"):
381
+ raise ValueError(
382
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
383
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
384
+ )
385
+
386
+ query_states = self.q_attn(hidden_states)
387
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
388
+ attention_mask = encoder_attention_mask
389
+ else:
390
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
391
+
392
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
393
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
394
+
395
+ query_states = query_states.view(shape_q).transpose(1, 2)
396
+ key_states = key_states.view(shape_kv).transpose(1, 2)
397
+ value_states = value_states.view(shape_kv).transpose(1, 2)
398
+
399
+ if layer_past is not None:
400
+ past_key, past_value = layer_past
401
+ key_states = torch.cat((past_key, key_states), dim=-2)
402
+ value_states = torch.cat((past_value, value_states), dim=-2)
403
+
404
+ if use_cache is True:
405
+ present = (key_states, value_states)
406
+ else:
407
+ present = None
408
+
409
+ is_cross_attention = encoder_hidden_states is not None
410
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
411
+
412
+ using_eager = self.config._attn_implementation == "eager"
413
+ attention_interface: Callable = eager_attention_forward
414
+ if self.config._attn_implementation != "eager":
415
+ if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
416
+ using_eager = True
417
+ logger.warning_once(
418
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
419
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
420
+ )
421
+ else:
422
+ # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
423
+ # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
424
+ # not necessarily to eager (if mentioned options are provided).
425
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
426
+
427
+ if using_eager and self.reorder_and_upcast_attn:
428
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
429
+ query_states, key_states, value_states, attention_mask, head_mask
430
+ )
431
+ else:
432
+ attn_output, attn_weights = attention_interface(
433
+ self,
434
+ query_states,
435
+ key_states,
436
+ value_states,
437
+ attention_mask,
438
+ head_mask=head_mask,
439
+ dropout=self.attn_dropout.p if self.training else 0.0,
440
+ is_causal=is_causal,
441
+ **kwargs,
442
+ )
443
+
444
+ # attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
445
+
446
+ outputs = (attn_output, present)
447
+ if output_attentions:
448
+ outputs += (attn_weights,)
449
+
450
+ return outputs # a, present, (attentions)
451
+
452
+
453
+ def apply_drift(self, attn_output):
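+ # Shifts each head's output by a sigmoid-gated, learned amount; with `baseline_each_head`
+ # the shift is proportional to the offset from the learned per-head `baseline`.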
454
+ drift_component = torch.sigmoid(self.drift_param.unsqueeze(0).unsqueeze(1))
455
+ if self.config.baseline_each_head:
456
+ attn_output = attn_output + drift_component * \
457
+ (attn_output - self.baseline.unsqueeze(0).unsqueeze(1))
458
+ else:
459
+ attn_output = attn_output + drift_component * attn_output
460
+ return attn_output
461
+
462
+ def forward(self, hidden_states, layer_past=None, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, use_cache=False, output_attentions=False, **kwargs):
463
+ gpt_attention_output = self.gpt_forward_without_cproj(
464
+ hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, **kwargs
465
+ )
466
+ if len(gpt_attention_output) == 3:
467
+ attn_output, present, attn_weights = gpt_attention_output
468
+ else:
469
+ attn_output, present = gpt_attention_output
470
+ if self.config.apply_drift:
471
+ attn_output = self.apply_drift(attn_output)
472
+ if self.config.apply_diffusion:
473
+ attn_output = self.apply_diffusion(attn_output)
474
+
475
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
476
+ attn_output = self.c_proj(attn_output)
477
+ attn_output = self.resid_dropout(attn_output)
478
+
479
+ if output_attentions:
480
+ return attn_output, present, attn_weights
481
+ return attn_output, present
482
+
483
+
484
+
485
+ class GPT2MLP(nn.Module):
486
+ def __init__(self, intermediate_size, config):
487
+ super().__init__()
488
+ embed_dim = config.hidden_size
489
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
490
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
491
+ self.act = ACT2FN[config.activation_function]
492
+ self.dropout = nn.Dropout(config.resid_pdrop)
493
+
494
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
495
+ hidden_states = self.c_fc(hidden_states)
496
+ hidden_states = self.act(hidden_states)
497
+ hidden_states = self.c_proj(hidden_states)
498
+ hidden_states = self.dropout(hidden_states)
499
+ return hidden_states
500
+
501
+
502
+ class GPT2Block(nn.Module):
503
+ def __init__(self, config, layer_idx=None):
504
+ super().__init__()
505
+ hidden_size = config.hidden_size
506
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
507
+
508
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
509
+ if config.apply_drift or config.apply_diffusion:
510
+ self.attn = GPT2AttentionWithDD(config=config, layer_idx=layer_idx)
511
+ else:
512
+ self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
513
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
514
+
515
+ if config.add_cross_attention:
516
+ self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
517
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
518
+
519
+ self.mlp = GPT2MLP(inner_dim, config)
520
+
521
+ def forward(
522
+ self,
523
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
524
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
525
+ attention_mask: Optional[torch.FloatTensor] = None,
526
+ head_mask: Optional[torch.FloatTensor] = None,
527
+ encoder_hidden_states: Optional[torch.Tensor] = None,
528
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
529
+ use_cache: Optional[bool] = False,
530
+ output_attentions: Optional[bool] = False,
531
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
532
+ residual = hidden_states
533
+ hidden_states = self.ln_1(hidden_states)
534
+ attn_outputs = self.attn(
535
+ hidden_states,
536
+ layer_past=layer_past,
537
+ attention_mask=attention_mask,
538
+ head_mask=head_mask,
539
+ use_cache=use_cache,
540
+ output_attentions=output_attentions,
541
+ )
542
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
543
+ outputs = attn_outputs[1:]
544
+ # residual connection
545
+ hidden_states = attn_output + residual
546
+
547
+ if encoder_hidden_states is not None:
548
+ # add one self-attention block for cross-attention
549
+ if not hasattr(self, "crossattention"):
550
+ raise ValueError(
551
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
552
+ "cross-attention layers by setting `config.add_cross_attention=True`"
553
+ )
554
+ residual = hidden_states
555
+ hidden_states = self.ln_cross_attn(hidden_states)
556
+ cross_attn_outputs = self.crossattention(
557
+ hidden_states,
558
+ attention_mask=attention_mask,
559
+ head_mask=head_mask,
560
+ encoder_hidden_states=encoder_hidden_states,
561
+ encoder_attention_mask=encoder_attention_mask,
562
+ output_attentions=output_attentions,
563
+ )
564
+ attn_output = cross_attn_outputs[0]
565
+ # residual connection
566
+ hidden_states = residual + attn_output
567
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
568
+
569
+ residual = hidden_states
570
+ hidden_states = self.ln_2(hidden_states)
571
+ feed_forward_hidden_states = self.mlp(hidden_states)
572
+ # residual connection
573
+ hidden_states = residual + feed_forward_hidden_states
574
+
575
+ if use_cache:
576
+ outputs = (hidden_states,) + outputs
577
+ else:
578
+ outputs = (hidden_states,) + outputs[1:]
579
+
580
+ return outputs # hidden_states, present, (attentions, cross_attentions)
581
+
582
+
583
+ class DDGPT2PretrainedModel(PreTrainedModel):
584
+ """
585
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
586
+ models.
587
+ """
588
+
589
+ config_class = DDGPT2Config
590
+ load_tf_weights = load_tf_weights_in_gpt2
591
+ base_model_prefix = "transformer"
592
+ is_parallelizable = True
593
+ supports_gradient_checkpointing = True
594
+ _no_split_modules = ["GPT2Block"]
595
+ _skip_keys_device_placement = "past_key_values"
596
+ _supports_flash_attn_2 = True
597
+ _supports_sdpa = True
598
+
599
+ def __init__(self, *inputs, **kwargs):
600
+ super().__init__(*inputs, **kwargs)
601
+
602
+ def _init_weights(self, module):
603
+ """Initialize the weights."""
604
+ if isinstance(module, (nn.Linear, Conv1D)):
605
+ # Slightly different from the TF version which uses truncated_normal for initialization
606
+ # cf https://github.com/pytorch/pytorch/pull/5617
607
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
608
+ if module.bias is not None:
609
+ module.bias.data.zero_()
610
+ elif isinstance(module, nn.Embedding):
611
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
612
+ if module.padding_idx is not None:
613
+ module.weight.data[module.padding_idx].zero_()
614
+ elif isinstance(module, nn.LayerNorm):
615
+ module.bias.data.zero_()
616
+ module.weight.data.fill_(1.0)
617
+
618
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
619
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
620
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
621
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
622
+ #
623
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
624
+ for name, p in module.named_parameters():
625
+ if name == "c_proj.weight":
626
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
627
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
628
+
629
+
630
+ @dataclass
631
+ class GPT2DoubleHeadsModelOutput(ModelOutput):
632
+ """
633
+ Base class for outputs of models predicting if two sentences are consecutive or not.
634
+
635
+ Args:
636
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
637
+ Language modeling loss.
638
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
639
+ Multiple choice classification loss.
640
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
641
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
642
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
643
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
644
+ past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
645
+ Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
646
+ sequence_length, embed_size_per_head)`).
647
+
648
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
649
+ `past_key_values` input) to speed up sequential decoding.
650
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
651
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
652
+ shape `(batch_size, sequence_length, hidden_size)`.
653
+
654
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
655
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
656
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
657
+ sequence_length)`.
658
+
659
+ GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
660
+ self-attention heads.
661
+ """
662
+
663
+ loss: Optional[torch.FloatTensor] = None
664
+ mc_loss: Optional[torch.FloatTensor] = None
665
+ logits: torch.FloatTensor = None
666
+ mc_logits: torch.FloatTensor = None
667
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
668
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
669
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
670
+
671
+
672
+
673
+ class DDGPT2Model(DDGPT2PretrainedModel):
674
+ _supports_param_buffer_assignment = False
675
+
676
+ def __init__(self, config):
677
+ super().__init__(config)
678
+
679
+ self.embed_dim = config.hidden_size
680
+
681
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
682
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
683
+
684
+ self.drop = nn.Dropout(config.embd_pdrop)
685
+ self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
686
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
687
+
688
+ # Model parallel
689
+ self.model_parallel = False
690
+ self.device_map = None
691
+ self.gradient_checkpointing = False
692
+ self._attn_implementation = config._attn_implementation
693
+
694
+ # Initialize weights and apply final processing
695
+ self.post_init()
696
+
697
+ def parallelize(self, device_map=None):
698
+ # Check validity of device_map
699
+ warnings.warn(
700
+ "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
701
+ " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
702
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
703
+ " ...}",
704
+ FutureWarning,
705
+ )
706
+ self.device_map = (
707
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
708
+ )
709
+ assert_device_map(self.device_map, len(self.h))
710
+ self.model_parallel = True
711
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
712
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
713
+ self.wte = self.wte.to(self.first_device)
714
+ self.wpe = self.wpe.to(self.first_device)
715
+ # Load onto devices
716
+ for k, v in self.device_map.items():
717
+ for block in v:
718
+ cuda_device = "cuda:" + str(k)
719
+ self.h[block] = self.h[block].to(cuda_device)
720
+ # ln_f to last
721
+ self.ln_f = self.ln_f.to(self.last_device)
722
+
723
+ def deparallelize(self):
724
+ warnings.warn(
725
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
726
+ FutureWarning,
727
+ )
728
+ self.model_parallel = False
729
+ self.device_map = None
730
+ self.first_device = "cpu"
731
+ self.last_device = "cpu"
732
+ self.wte = self.wte.to("cpu")
733
+ self.wpe = self.wpe.to("cpu")
734
+ for index in range(len(self.h)):
735
+ self.h[index] = self.h[index].to("cpu")
736
+ self.ln_f = self.ln_f.to("cpu")
737
+ torch.cuda.empty_cache()
738
+
739
+ def get_input_embeddings(self):
740
+ return self.wte
741
+
742
+ def set_input_embeddings(self, new_embeddings):
743
+ self.wte = new_embeddings
744
+
745
+ def _prune_heads(self, heads_to_prune):
746
+ """
747
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
748
+ """
749
+ for layer, heads in heads_to_prune.items():
750
+ self.h[layer].attn.prune_heads(heads)
751
+
752
+ def forward(
753
+ self,
754
+ input_ids: Optional[torch.LongTensor] = None,
755
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
756
+ attention_mask: Optional[torch.FloatTensor] = None,
757
+ token_type_ids: Optional[torch.LongTensor] = None,
758
+ position_ids: Optional[torch.LongTensor] = None,
759
+ head_mask: Optional[torch.FloatTensor] = None,
760
+ inputs_embeds: Optional[torch.FloatTensor] = None,
761
+ encoder_hidden_states: Optional[torch.Tensor] = None,
762
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
763
+ use_cache: Optional[bool] = None,
764
+ output_attentions: Optional[bool] = None,
765
+ output_hidden_states: Optional[bool] = None,
766
+ return_dict: Optional[bool] = None,
767
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
768
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
769
+ output_hidden_states = (
770
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
771
+ )
772
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
773
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
774
+
775
+ if input_ids is not None and inputs_embeds is not None:
776
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
777
+ elif input_ids is not None:
778
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
779
+ input_shape = input_ids.size()
780
+ input_ids = input_ids.view(-1, input_shape[-1])
781
+ batch_size = input_ids.shape[0]
782
+ elif inputs_embeds is not None:
783
+ input_shape = inputs_embeds.size()[:-1]
784
+ batch_size = inputs_embeds.shape[0]
785
+ else:
786
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
787
+
788
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
789
+
790
+ if token_type_ids is not None:
791
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
792
+
793
+ if past_key_values is None:
794
+ past_length = 0
795
+ past_key_values = tuple([None] * len(self.h))
796
+ else:
797
+ past_length = past_key_values[0][0].size(-2)
798
+ if position_ids is None:
799
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
800
+ position_ids = position_ids.unsqueeze(0)
801
+
802
+ if inputs_embeds is None:
803
+ inputs_embeds = self.wte(input_ids)
804
+ position_embeds = self.wpe(position_ids)
805
+ hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
806
+
807
+ # Attention mask.
808
+ _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
809
+ attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
810
+ if self._attn_implementation == "flash_attention_2":
811
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
812
+ elif _use_sdpa:
813
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
814
+ attention_mask=attention_mask,
815
+ input_shape=(batch_size, input_shape[-1]),
816
+ inputs_embeds=inputs_embeds,
817
+ past_key_values_length=past_length,
818
+ )
819
+ else:
820
+ if attention_mask is not None:
821
+ # We create a 3D attention mask from a 2D tensor mask.
822
+ # Sizes are [batch_size, 1, 1, to_seq_length]
823
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
824
+ # this attention mask is more simple than the triangular masking of causal attention
825
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
826
+ attention_mask = attention_mask[:, None, None, :]
827
+
828
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
829
+ # masked positions, this operation will create a tensor which is 0.0 for
830
+ # positions we want to attend and the dtype's smallest value for masked positions.
831
+ # Since we are adding it to the raw scores before the softmax, this is
832
+ # effectively the same as removing these entirely.
833
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
834
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
835
+
836
+ # If a 2D or 3D attention mask is provided for the cross-attention
837
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
838
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
839
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
840
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
841
+ if encoder_attention_mask is None:
842
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
843
+ if _use_sdpa:
844
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
845
+ mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
846
+ )
847
+ elif not self._attn_implementation == "flash_attention_2":
848
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
849
+ else:
850
+ encoder_attention_mask = None
851
+
852
+ # Prepare head mask if needed
853
+ # 1.0 in head_mask indicate we keep the head
854
+ # attention_probs has shape bsz x n_heads x N x N
855
+ # head_mask has shape n_layer x batch x n_heads x N x N
856
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
857
+
858
+ if token_type_ids is not None:
859
+ token_type_embeds = self.wte(token_type_ids)
860
+ hidden_states = hidden_states + token_type_embeds
861
+
862
+ hidden_states = self.drop(hidden_states)
863
+
864
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
865
+
866
+ if self.gradient_checkpointing and self.training:
867
+ if use_cache:
868
+ logger.warning_once(
869
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
870
+ )
871
+ use_cache = False
872
+
873
+ presents = () if use_cache else None
874
+ all_self_attentions = () if output_attentions else None
875
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
876
+ all_hidden_states = () if output_hidden_states else None
877
+ for i in range(len(self.h)):
878
+ block, layer_past = self.h[i], past_key_values[i]
879
+ # Model parallel
880
+ if self.model_parallel:
881
+ torch.cuda.set_device(hidden_states.device)
882
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
883
+ if layer_past is not None:
884
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
885
+ # Ensure that attention_mask is always on the same device as hidden_states
886
+ if attention_mask is not None:
887
+ attention_mask = attention_mask.to(hidden_states.device)
888
+ if isinstance(head_mask, torch.Tensor):
889
+ head_mask = head_mask.to(hidden_states.device)
890
+ if output_hidden_states:
891
+ all_hidden_states = all_hidden_states + (hidden_states,)
892
+
893
+ if self.gradient_checkpointing and self.training:
894
+ outputs = self._gradient_checkpointing_func(
895
+ block.__call__,
896
+ hidden_states,
897
+ None,
898
+ attention_mask,
899
+ head_mask[i],
900
+ encoder_hidden_states,
901
+ encoder_attention_mask,
902
+ use_cache,
903
+ output_attentions,
904
+ )
905
+ else:
906
+ outputs = block(
907
+ hidden_states,
908
+ layer_past=layer_past,
909
+ attention_mask=attention_mask,
910
+ head_mask=head_mask[i],
911
+ encoder_hidden_states=encoder_hidden_states,
912
+ encoder_attention_mask=encoder_attention_mask,
913
+ use_cache=use_cache,
914
+ output_attentions=output_attentions,
915
+ )
916
+
917
+ hidden_states = outputs[0]
918
+ if use_cache is True:
919
+ presents = presents + (outputs[1],)
920
+
921
+ if output_attentions:
922
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
923
+ if self.config.add_cross_attention:
924
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
925
+
926
+ # Model Parallel: If it's the last layer for that device, put things on the next device
927
+ if self.model_parallel:
928
+ for k, v in self.device_map.items():
929
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
930
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
931
+
932
+ hidden_states = self.ln_f(hidden_states)
933
+
934
+ hidden_states = hidden_states.view(output_shape)
935
+ # Add last hidden state
936
+ if output_hidden_states:
937
+ all_hidden_states = all_hidden_states + (hidden_states,)
938
+
939
+ if not return_dict:
940
+ return tuple(
941
+ v
942
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
943
+ if v is not None
944
+ )
945
+
946
+ return BaseModelOutputWithPastAndCrossAttentions(
947
+ last_hidden_state=hidden_states,
948
+ past_key_values=presents,
949
+ hidden_states=all_hidden_states,
950
+ attentions=all_self_attentions,
951
+ cross_attentions=all_cross_attentions,
952
+ )
953
+
954
+
955
+ class DDGPT2LMHeadModel(DDGPT2PretrainedModel, GenerationMixin):
956
+ _tied_weights_keys = ["lm_head.weight"]
957
+
958
+ def __init__(self, config):
959
+ super().__init__(config)
960
+ self.transformer = DDGPT2Model(config)
961
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
962
+
963
+ # Model parallel
964
+ self.model_parallel = False
965
+ self.device_map = None
966
+
967
+ # Initialize weights and apply final processing
968
+ self.post_init()
969
+
970
+ def parallelize(self, device_map=None):
971
+ warnings.warn(
972
+ "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
973
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
974
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
975
+ " 0, 'transformer.h.1': 1, ...}",
976
+ FutureWarning,
977
+ )
978
+ self.device_map = (
979
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
980
+ if device_map is None
981
+ else device_map
982
+ )
983
+ assert_device_map(self.device_map, len(self.transformer.h))
984
+ self.transformer.parallelize(self.device_map)
985
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
986
+ self.model_parallel = True
987
+
988
+ def deparallelize(self):
989
+ warnings.warn(
990
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
991
+ FutureWarning,
992
+ )
993
+ self.transformer.deparallelize()
994
+ self.transformer = self.transformer.to("cpu")
995
+ self.lm_head = self.lm_head.to("cpu")
996
+ self.model_parallel = False
997
+ torch.cuda.empty_cache()
998
+
999
+ def get_output_embeddings(self):
1000
+ return self.lm_head
1001
+
1002
+ def set_output_embeddings(self, new_embeddings):
1003
+ self.lm_head = new_embeddings
1004
+
1005
+ def forward(
1006
+ self,
1007
+ input_ids: Optional[torch.LongTensor] = None,
1008
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1009
+ attention_mask: Optional[torch.FloatTensor] = None,
1010
+ token_type_ids: Optional[torch.LongTensor] = None,
1011
+ position_ids: Optional[torch.LongTensor] = None,
1012
+ head_mask: Optional[torch.FloatTensor] = None,
1013
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1014
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1015
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1016
+ labels: Optional[torch.LongTensor] = None,
1017
+ use_cache: Optional[bool] = None,
1018
+ output_attentions: Optional[bool] = None,
1019
+ output_hidden_states: Optional[bool] = None,
1020
+ return_dict: Optional[bool] = None,
1021
+ **kwargs,
1022
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1023
+ r"""
1024
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1025
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1026
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1027
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1028
+ """
1029
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1030
+
1031
+ transformer_outputs = self.transformer(
1032
+ input_ids,
1033
+ past_key_values=past_key_values,
1034
+ attention_mask=attention_mask,
1035
+ token_type_ids=token_type_ids,
1036
+ position_ids=position_ids,
1037
+ head_mask=head_mask,
1038
+ inputs_embeds=inputs_embeds,
1039
+ encoder_hidden_states=encoder_hidden_states,
1040
+ encoder_attention_mask=encoder_attention_mask,
1041
+ use_cache=use_cache,
1042
+ output_attentions=output_attentions,
1043
+ output_hidden_states=output_hidden_states,
1044
+ return_dict=return_dict,
1045
+ )
1046
+ hidden_states = transformer_outputs[0]
1047
+
1048
+ # Set device for model parallelism
1049
+ if self.model_parallel:
1050
+ torch.cuda.set_device(self.transformer.first_device)
1051
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1052
+
1053
+ lm_logits = self.lm_head(hidden_states)
1054
+
1055
+ loss = None
1056
+ if labels is not None:
1057
+ # Flatten the tokens
1058
+ loss = self.loss_function(
1059
+ lm_logits,
1060
+ labels,
1061
+ vocab_size=self.config.vocab_size,
1062
+ **kwargs,
1063
+ )
1064
+
1065
+ if not return_dict:
1066
+ output = (lm_logits,) + transformer_outputs[1:]
1067
+ return ((loss,) + output) if loss is not None else output
1068
+
1069
+ return CausalLMOutputWithCrossAttentions(
1070
+ loss=loss,
1071
+ logits=lm_logits,
1072
+ past_key_values=transformer_outputs.past_key_values,
1073
+ hidden_states=transformer_outputs.hidden_states,
1074
+ attentions=transformer_outputs.attentions,
1075
+ cross_attentions=transformer_outputs.cross_attentions,
1076
+ )
1077
+
1078
+ @staticmethod
1079
+ def _reorder_cache(
1080
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1081
+ ) -> Tuple[Tuple[torch.Tensor]]:
1082
+ """
1083
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1084
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1085
+ beam_idx at every generation step.
1086
+ """
1087
+ return tuple(
1088
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1089
+ for layer_past in past_key_values
1090
+ )
1091
+
1092
+ __all__ = [
1093
+ "DDGPT2LMHeadModel",
1094
+ "DDGPT2Model",
1095
+ "DDGPT2PretrainedModel",
1096
+ "load_tf_weights_in_gpt2",
1097
+ ]
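+ # Note: to load this file from the Hub via `trust_remote_code=True`, the repository's config.json is
+ # typically expected to reference these classes through an `auto_map` entry (e.g. mapping AutoConfig to
+ # DDGPT2Config and AutoModelForCausalLM to DDGPT2LMHeadModel); this is an assumption about the repo
+ # setup, not something defined in this module.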
1098
+
1099
+
1100
+ if __name__ == "__main__":
1101
+ # Quick smoke test: extend the stock gpt2-medium config with the DD flags and run one forward pass.
+ cg = GPT2Config.from_pretrained("gpt2-medium")
+ cg.apply_drift = True
+ cg.apply_diffusion = True
+ cg.baseline_each_head = True
+ model = DDGPT2LMHeadModel(cg)
+ from src.utils.model_utlis import print_trainable_parameters  # helper from the local training repo
+ print_trainable_parameters(model)
+ model(torch.randint(0, 10000, (1, 100)))
+ print()