shivanandmn committed on
Commit 9b295f1 · verified · 1 Parent(s): 16a3c8d

Model save

Files changed (3)
  1. README.md +52 -0
  2. generation_config.json +7 -0
  3. modeling_duo_predict_gpt2.py +901 -0
README.md ADDED
@@ -0,0 +1,52 @@
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - generated_from_trainer
5
+ model-index:
6
+ - name: duo-predict-gpt2-medium-wikitext
7
+ results: []
8
+ ---
9
+
10
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
11
+ should probably proofread and complete it, then remove this comment. -->
12
+
13
+ # duo-predict-gpt2-medium-wikitext
14
+
15
+ This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
16
+
17
+ ## Model description
18
+
19
+ More information needed
20
+
21
+ ## Intended uses & limitations
22
+
23
+ More information needed
24
+
25
+ ## Training and evaluation data
26
+
27
+ More information needed
28
+
29
+ ## Training procedure
30
+
31
+ ### Training hyperparameters
32
+
33
+ The following hyperparameters were used during training:
34
+ - learning_rate: 0.0001
35
+ - train_batch_size: 8
36
+ - eval_batch_size: 8
37
+ - seed: 42
38
+ - optimizer: AdamW (torch implementation) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
39
+ - lr_scheduler_type: linear
40
+ - lr_scheduler_warmup_ratio: 0.1
41
+ - num_epochs: 5
42
+
43
+ ### Training results
44
+
45
+
46
+
47
+ ### Framework versions
48
+
49
+ - Transformers 4.49.0
50
+ - Pytorch 2.6.0+cu124
51
+ - Datasets 3.3.2
52
+ - Tokenizers 0.21.0
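
The hyperparameters listed above map onto a `transformers` `TrainingArguments` setup roughly like the following sketch; the output directory and the commented-out model/dataset objects are placeholders, not taken from this commit.

```python
from transformers import Trainer, TrainingArguments

# Illustrative reconstruction of the reported hyperparameters
# (output_dir, model and datasets are placeholders).
training_args = TrainingArguments(
    output_dir="duo-predict-gpt2-medium-wikitext",
    learning_rate=1e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    seed=42,
    optim="adamw_torch",        # AdamW with betas=(0.9, 0.999) and eps=1e-08 (the defaults)
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    num_train_epochs=5,
)

# trainer = Trainer(model=model, args=training_args,
#                   train_dataset=train_dataset, eval_dataset=eval_dataset)
# trainer.train()
```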
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50256,
4
+ "eos_token_id": 50256,
5
+ "transformers_version": "4.49.0",
6
+ "use_cache": false
7
+ }
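
Expressed through the `transformers` API, the same generation settings correspond to this minimal sketch:

```python
from transformers import GenerationConfig

# Mirrors generation_config.json above: GPT-2's <|endoftext|> token (id 50256)
# acts as both BOS and EOS, and KV caching is disabled.
gen_config = GenerationConfig(bos_token_id=50256, eos_token_id=50256, use_cache=False)
```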
modeling_duo_predict_gpt2.py ADDED
@@ -0,0 +1,901 @@
1
+
2
+ """PyTorch OpenAI GPT-2 model, code copied from Huggingface"""
3
+
4
+ import math
5
+ import os
6
+ import warnings
7
+ from dataclasses import dataclass
8
+ from typing import Callable, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.generation import GenerationMixin
17
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
18
+ from transformers.modeling_outputs import (
19
+ BaseModelOutputWithPastAndCrossAttentions,
20
+ CausalLMOutputWithCrossAttentions,
21
+ QuestionAnsweringModelOutput,
22
+ SequenceClassifierOutputWithPast,
23
+ TokenClassifierOutput,
24
+ )
25
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary
26
+ from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
27
+ from transformers.utils import (
28
+ ModelOutput,
29
+ add_code_sample_docstrings,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
36
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
37
+ from src.models.modeling_gpt2 import GPT2PreTrainedModel, GPT2Block
38
+ from transformers.models.llama.modeling_llama import repeat_kv  # helper used by sdpa_attention_forward's GQA branch
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+
45
+ def create_attention_mask_matrix(tn):
46
+ # Initialize the (tn x tn) mask matrix
47
+ tn = tn + 1 ### add 1 for the extra token to create the correct matrix, temporary fix
48
+ matrix = torch.zeros(tn, tn)
49
+
50
+ # Define odd columns mask (j=1,3,5,7,9)
51
+ odd_cols = torch.arange(tn) % 2 == 1 # [False, True, False, True, ..., True]
52
+
53
+ # Define row indices
54
+ odd_rows = torch.tensor([x for x in range(1, tn) if x%2==1])
55
+ even_rows = torch.tensor([x for x in range(1, tn) if x%2==0])
56
+
57
+ # For odd rows: ones at odd columns j ≤ i
58
+ # Use tril to get 1s where j ≤ i, then mask with odd columns
59
+ tril_matrix = torch.tril(torch.ones(tn, tn))
60
+ matrix[odd_rows, :] = tril_matrix[odd_rows, :] * odd_cols
61
+
62
+ # For even rows: ones at odd j ≤ i-2, plus j=i and j=i+1
63
+ # Use tril with diagonal=-2 for j ≤ i-2, mask with odd columns
64
+ tril_minus2 = torch.tril(torch.ones(tn, tn), diagonal=-2)
65
+ matrix[even_rows, :] = tril_minus2[even_rows, :] * odd_cols
66
+ # Set specific positions for even rows
67
+ matrix[even_rows, even_rows] = 1 # j = i
68
+ matrix[even_rows, even_rows + 1] = 1 # j = i+1
69
+
70
+ return matrix[1:, 1:].bool()
71
+
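+ # Illustrative example: for a query length of 5 the mask returned above is
+ # [[1, 0, 0, 0, 0],
+ #  [0, 1, 1, 0, 0],
+ #  [1, 0, 1, 0, 0],
+ #  [1, 0, 0, 1, 1],
+ #  [1, 0, 1, 0, 1]]
+ # i.e. even-indexed positions attend causally over the even-indexed slots, while
+ # odd-indexed positions attend to the even-indexed slots of earlier pairs plus
+ # themselves and the position immediately after them. The construction appears to
+ # assume an odd sequence length (for an even length, `even_rows + 1` indexes one
+ # past the last column), and sdpa_attention_forward below always rebuilds this mask
+ # from the current query length, ignoring any attention_mask that is passed in.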
72
+ # Reference (non-fused) implementation of torch.nn.functional.scaled_dot_product_attention, adapted from the PyTorch docs:
73
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
74
+ is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
75
+ L, S = query.size(-2), key.size(-2)
76
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
77
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
78
+ if is_causal:
79
+ assert attn_mask is None
80
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
81
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
82
+ attn_bias.to(query.dtype)
83
+
84
+ if attn_mask is not None:
85
+ if attn_mask.dtype == torch.bool:
86
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
87
+ else:
88
+ attn_bias = attn_mask + attn_bias
89
+
90
+ if enable_gqa:
91
+ key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
92
+ value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
93
+
94
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
95
+ attn_weight += attn_bias
96
+ attn_weight = torch.softmax(attn_weight, dim=-1)
97
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
98
+ return attn_weight @ value
99
+
100
+ def sdpa_attention_forward(
101
+ module: torch.nn.Module,
102
+ query: torch.Tensor,
103
+ key: torch.Tensor,
104
+ value: torch.Tensor,
105
+ attention_mask: Optional[torch.Tensor],
106
+ dropout: float = 0.0,
107
+ scaling: Optional[float] = None,
108
+ is_causal: Optional[bool] = None,
109
+ **kwargs,
110
+ ) -> Tuple[torch.Tensor, None]:
111
+ if hasattr(module, "num_key_value_groups"):
112
+ key = repeat_kv(key, module.num_key_value_groups)
113
+ value = repeat_kv(value, module.num_key_value_groups)
114
+
115
+
116
+ # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
117
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
118
+ query = query.contiguous()
119
+ key = key.contiguous()
120
+ value = value.contiguous()
121
+
122
+
123
+ # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
124
+ # We convert it to a bool for the SDPA kernel that only accepts bools.
125
+ if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
126
+ is_causal = is_causal.item()
127
+
128
+ attn_output = scaled_dot_product_attention(
129
+ query,
130
+ key,
131
+ value,
132
+ attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device),
133
+ dropout_p=dropout,
134
+ scale=scaling,
135
+ is_causal=is_causal,
136
+ )
137
+ attn_output = attn_output.transpose(1, 2).contiguous()
138
+
139
+ return attn_output, None
140
+
141
+
142
+ class DuoPredictGPT2Config(GPT2Config):
143
+ model_type = "duo-predict-gpt2"
144
+ architectures = ["DuoPredictGPT2LMHeadModel"]
145
+
146
+
147
+ class DuoPredictGPT2Attention(nn.Module):
148
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
149
+ super().__init__()
150
+ self.config = config
151
+ max_positions = config.max_position_embeddings
152
+ self.register_buffer(
153
+ "bias",
154
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
155
+ 1, 1, max_positions, max_positions
156
+ ),
157
+ persistent=False,
158
+ )
159
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
160
+
161
+ self.embed_dim = config.hidden_size
162
+ self.num_heads = config.num_attention_heads
163
+ self.head_dim = self.embed_dim // self.num_heads
164
+ self.split_size = self.embed_dim
165
+ if self.head_dim * self.num_heads != self.embed_dim:
166
+ raise ValueError(
167
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
168
+ f" {self.num_heads})."
169
+ )
170
+
171
+ self.scale_attn_weights = config.scale_attn_weights
172
+ self.is_cross_attention = is_cross_attention
173
+
174
+ # Layer-wise attention scaling, reordering, and upcasting
175
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
176
+ self.layer_idx = layer_idx
177
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
178
+
179
+ if self.is_cross_attention:
180
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
181
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
182
+ else:
183
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
184
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
185
+
186
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
187
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
188
+ self.is_causal = True
189
+
190
+ self.pruned_heads = set()
191
+
192
+ def prune_heads(self, heads):
193
+ if len(heads) == 0:
194
+ return
195
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
196
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
197
+
198
+ # Prune conv1d layers
199
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
200
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
201
+
202
+ # Update hyper params
203
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
204
+ self.num_heads = self.num_heads - len(heads)
205
+ self.pruned_heads = self.pruned_heads.union(heads)
206
+
207
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
208
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
209
+ bsz, num_heads, q_seq_len, dk = query.size()
210
+ _, _, k_seq_len, _ = key.size()
211
+
212
+ # Preallocate attn_weights for `baddbmm`
213
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
214
+
215
+ # Compute Scale Factor
216
+ scale_factor = 1.0
217
+ if self.scale_attn_weights:
218
+ scale_factor /= float(value.size(-1)) ** 0.5
219
+
220
+ if self.scale_attn_by_inverse_layer_idx:
221
+ scale_factor /= float(self.layer_idx + 1)
222
+
223
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
224
+ with torch.amp.autocast(query.device.type, enabled=False):
225
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
226
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
227
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
228
+
229
+ if not self.is_cross_attention:
230
+ # if only "normal" attention layer implements causal mask
231
+ query_length, key_length = query.size(-2), key.size(-2)
232
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
233
+ mask_value = torch.finfo(attn_weights.dtype).min
234
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
235
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
236
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
237
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
238
+
239
+ if attention_mask is not None:
240
+ # Apply the attention mask
241
+ attn_weights = attn_weights + attention_mask
242
+
243
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
244
+
245
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
246
+ if attn_weights.dtype != torch.float32:
247
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
248
+ attn_weights = attn_weights.type(value.dtype)
249
+ attn_weights = self.attn_dropout(attn_weights)
250
+
251
+ # Mask heads if we want to
252
+ if head_mask is not None:
253
+ attn_weights = attn_weights * head_mask
254
+
255
+ attn_output = torch.matmul(attn_weights, value)
256
+ attn_output = attn_output.transpose(1, 2)
257
+
258
+ return attn_output, attn_weights
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
263
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
264
+ attention_mask: Optional[torch.FloatTensor] = None,
265
+ head_mask: Optional[torch.FloatTensor] = None,
266
+ encoder_hidden_states: Optional[torch.Tensor] = None,
267
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
268
+ use_cache: Optional[bool] = False,
269
+ output_attentions: Optional[bool] = False,
270
+ **kwargs,
271
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
272
+ if encoder_hidden_states is not None:
273
+ if not hasattr(self, "q_attn"):
274
+ raise ValueError(
275
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
276
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
277
+ )
278
+
279
+ query_states = self.q_attn(hidden_states)
280
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
281
+ attention_mask = encoder_attention_mask
282
+ else:
283
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
284
+
285
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
286
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
287
+
288
+ query_states = query_states.view(shape_q).transpose(1, 2)
289
+ key_states = key_states.view(shape_kv).transpose(1, 2)
290
+ value_states = value_states.view(shape_kv).transpose(1, 2)
291
+
292
+ if layer_past is not None:
293
+ past_key, past_value = layer_past
294
+ key_states = torch.cat((past_key, key_states), dim=-2)
295
+ value_states = torch.cat((past_value, value_states), dim=-2)
296
+
297
+ if use_cache is True:
298
+ present = (key_states, value_states)
299
+ else:
300
+ present = None
301
+
302
+ is_cross_attention = encoder_hidden_states is not None
303
+ is_causal = False #attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
304
+
305
+ using_eager = self.config._attn_implementation == "eager"
306
+ # attention_interface: Callable = eager_attention_forward
307
+ # if self.config._attn_implementation != "eager":
308
+ # if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
309
+ # using_eager = True
310
+ # logger.warning_once(
311
+ # "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
312
+ # 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
313
+ # )
314
+ # else:
315
+ # # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
316
+ # # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
317
+ # # not necessarily to eager (if mentioned options are provided).
318
+ # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
319
+
320
+ attention_interface = sdpa_attention_forward
321
+
322
+ if using_eager and self.reorder_and_upcast_attn:
323
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
324
+ query_states, key_states, value_states, attention_mask, head_mask
325
+ )
326
+ else:
327
+ attn_output, attn_weights = attention_interface(
328
+ self,
329
+ query_states,
330
+ key_states,
331
+ value_states,
332
+ attention_mask,
333
+ head_mask=head_mask,
334
+ dropout=self.attn_dropout.p if self.training else 0.0,
335
+ is_causal=is_causal,
336
+ **kwargs,
337
+ )
338
+
339
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
340
+ attn_output = self.c_proj(attn_output)
341
+ attn_output = self.resid_dropout(attn_output)
342
+
343
+ outputs = (attn_output, present)
344
+ if output_attentions:
345
+ outputs += (attn_weights,)
346
+
347
+ return outputs # a, present, (attentions)
348
+
349
+
350
+ class DuoPredictGPT2MLP(nn.Module):
351
+ def __init__(self, intermediate_size, config):
352
+ super().__init__()
353
+ embed_dim = config.hidden_size
354
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
355
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
356
+ self.act = ACT2FN[config.activation_function]
357
+ self.dropout = nn.Dropout(config.resid_pdrop)
358
+
359
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
360
+ hidden_states = self.c_fc(hidden_states)
361
+ hidden_states = self.act(hidden_states)
362
+ hidden_states = self.c_proj(hidden_states)
363
+ hidden_states = self.dropout(hidden_states)
364
+ return hidden_states
365
+
366
+
367
+ class DuoPredictGPT2Block(nn.Module):
368
+ def __init__(self, config, layer_idx=None):
369
+ super().__init__()
370
+ hidden_size = config.hidden_size
371
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
372
+
373
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
374
+ self.attn = DuoPredictGPT2Attention(config=config, layer_idx=layer_idx)
375
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
376
+
377
+ if config.add_cross_attention:
378
+ self.crossattention = DuoPredictGPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
379
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
380
+
381
+ self.mlp = DuoPredictGPT2MLP(inner_dim, config)
382
+
383
+ def forward(
384
+ self,
385
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
386
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
387
+ attention_mask: Optional[torch.FloatTensor] = None,
388
+ head_mask: Optional[torch.FloatTensor] = None,
389
+ encoder_hidden_states: Optional[torch.Tensor] = None,
390
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
391
+ use_cache: Optional[bool] = False,
392
+ output_attentions: Optional[bool] = False,
393
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
394
+ residual = hidden_states
395
+ hidden_states = self.ln_1(hidden_states)
396
+ attn_outputs = self.attn(
397
+ hidden_states,
398
+ layer_past=layer_past,
399
+ attention_mask=attention_mask,
400
+ head_mask=head_mask,
401
+ use_cache=use_cache,
402
+ output_attentions=output_attentions,
403
+ )
404
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
405
+ outputs = attn_outputs[1:]
406
+ # residual connection
407
+ hidden_states = attn_output + residual
408
+
409
+ if encoder_hidden_states is not None:
410
+ # add one self-attention block for cross-attention
411
+ if not hasattr(self, "crossattention"):
412
+ raise ValueError(
413
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
414
+ "cross-attention layers by setting `config.add_cross_attention=True`"
415
+ )
416
+ residual = hidden_states
417
+ hidden_states = self.ln_cross_attn(hidden_states)
418
+ cross_attn_outputs = self.crossattention(
419
+ hidden_states,
420
+ attention_mask=attention_mask,
421
+ head_mask=head_mask,
422
+ encoder_hidden_states=encoder_hidden_states,
423
+ encoder_attention_mask=encoder_attention_mask,
424
+ output_attentions=output_attentions,
425
+ )
426
+ attn_output = cross_attn_outputs[0]
427
+ # residual connection
428
+ hidden_states = residual + attn_output
429
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
430
+
431
+ residual = hidden_states
432
+ hidden_states = self.ln_2(hidden_states)
433
+ feed_forward_hidden_states = self.mlp(hidden_states)
434
+ # residual connection
435
+ hidden_states = residual + feed_forward_hidden_states
436
+
437
+ if use_cache:
438
+ outputs = (hidden_states,) + outputs
439
+ else:
440
+ outputs = (hidden_states,) + outputs[1:]
441
+
442
+ return outputs # hidden_states, present, (attentions, cross_attentions)
443
+
444
+
445
+ class DuoPredictGPT2PretrainedModel(GPT2PreTrainedModel):
446
+ config_class = DuoPredictGPT2Config
447
+
448
+
449
+ class DuoPredictGPT2Model(DuoPredictGPT2PretrainedModel):
450
+ _supports_param_buffer_assignment = False
451
+
452
+ def __init__(self, config):
453
+ super().__init__(config)
454
+
455
+ self.embed_dim = config.hidden_size
456
+
457
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
458
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
459
+
460
+ self.drop = nn.Dropout(config.embd_pdrop)
461
+ self.h = nn.ModuleList([DuoPredictGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
462
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
463
+
464
+ # Model parallel
465
+ self.model_parallel = False
466
+ self.device_map = None
467
+ self.gradient_checkpointing = False
468
+ self._attn_implementation = config._attn_implementation
469
+
470
+ # Initialize weights and apply final processing
471
+ self.post_init()
472
+
473
+ def parallelize(self, device_map=None):
474
+ # Check validity of device_map
475
+ warnings.warn(
476
+ "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
477
+ " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
478
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
479
+ " ...}",
480
+ FutureWarning,
481
+ )
482
+ self.device_map = (
483
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
484
+ )
485
+ assert_device_map(self.device_map, len(self.h))
486
+ self.model_parallel = True
487
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
488
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
489
+ self.wte = self.wte.to(self.first_device)
490
+ self.wpe = self.wpe.to(self.first_device)
491
+ # Load onto devices
492
+ for k, v in self.device_map.items():
493
+ for block in v:
494
+ cuda_device = "cuda:" + str(k)
495
+ self.h[block] = self.h[block].to(cuda_device)
496
+ # ln_f to last
497
+ self.ln_f = self.ln_f.to(self.last_device)
498
+
499
+
500
+ def deparallelize(self):
501
+ warnings.warn(
502
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
503
+ FutureWarning,
504
+ )
505
+ self.model_parallel = False
506
+ self.device_map = None
507
+ self.first_device = "cpu"
508
+ self.last_device = "cpu"
509
+ self.wte = self.wte.to("cpu")
510
+ self.wpe = self.wpe.to("cpu")
511
+ for index in range(len(self.h)):
512
+ self.h[index] = self.h[index].to("cpu")
513
+ self.ln_f = self.ln_f.to("cpu")
514
+ torch.cuda.empty_cache()
515
+
516
+ def get_input_embeddings(self):
517
+ return self.wte
518
+
519
+ def set_input_embeddings(self, new_embeddings):
520
+ self.wte = new_embeddings
521
+
522
+ def _prune_heads(self, heads_to_prune):
523
+ """
524
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
525
+ """
526
+ for layer, heads in heads_to_prune.items():
527
+ self.h[layer].attn.prune_heads(heads)
528
+
529
+
530
+ def forward(
531
+ self,
532
+ input_ids: Optional[torch.LongTensor] = None,
533
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
534
+ attention_mask: Optional[torch.FloatTensor] = None,
535
+ token_type_ids: Optional[torch.LongTensor] = None,
536
+ position_ids: Optional[torch.LongTensor] = None,
537
+ head_mask: Optional[torch.FloatTensor] = None,
538
+ inputs_embeds: Optional[torch.FloatTensor] = None,
539
+ encoder_hidden_states: Optional[torch.Tensor] = None,
540
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
541
+ use_cache: Optional[bool] = None,
542
+ output_attentions: Optional[bool] = None,
543
+ output_hidden_states: Optional[bool] = None,
544
+ return_dict: Optional[bool] = None,
545
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
546
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
547
+ output_hidden_states = (
548
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
549
+ )
550
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
551
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
552
+
553
+ if input_ids is not None and inputs_embeds is not None:
554
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
555
+ elif input_ids is not None:
556
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
557
+ input_shape = input_ids.size()
558
+ input_ids = input_ids.view(-1, input_shape[-1])
559
+ batch_size = input_ids.shape[0]
560
+ elif inputs_embeds is not None:
561
+ input_shape = inputs_embeds.size()[:-1]
562
+ batch_size = inputs_embeds.shape[0]
563
+ else:
564
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
565
+
566
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
567
+
568
+ if token_type_ids is not None:
569
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
570
+
571
+ if past_key_values is None:
572
+ past_length = 0
573
+ past_key_values = tuple([None] * len(self.h))
574
+ else:
575
+ past_length = past_key_values[0][0].size(-2)
576
+ if position_ids is None:
577
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
578
+ position_ids = position_ids.unsqueeze(0)
579
+ position_ids = position_ids[:, :self.config.max_position_embeddings] #TODO: remember
580
+
581
+ if inputs_embeds is None:
582
+ inputs_embeds = self.wte(input_ids)
583
+ position_embeds = self.wpe(position_ids)
584
+ ###TODO: correctly initialized
585
+ hidden_states = torch.empty((batch_size, input_shape[-1], self.embed_dim), device=device)
586
+ hidden_states[:, ::2] = inputs_embeds[:, ::2] + position_embeds.to(inputs_embeds.device)
587
+ hidden_states[:, 1::2] = inputs_embeds[:, 1::2] + position_embeds[:, :self.config.max_position_embeddings-1].to(inputs_embeds.device)
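+ # Note: each even/odd pair of input slots shares one position embedding here, and the
+ # two assignments above only broadcast when the input length equals
+ # 2 * max_position_embeddings - 1, so the model appears to expect pre-interleaved
+ # duo sequences of that odd length rather than plain GPT-2 inputs.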
588
+
589
+ # Attention mask.
590
+ _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
591
+ attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
592
+ if self._attn_implementation == "flash_attention_2":
593
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
594
+ elif _use_sdpa:
595
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
596
+ attention_mask=attention_mask,
597
+ input_shape=(batch_size, input_shape[-1]),
598
+ inputs_embeds=inputs_embeds,
599
+ past_key_values_length=past_length,
600
+ )
601
+ else:
602
+ if attention_mask is not None:
603
+ # We create a 3D attention mask from a 2D tensor mask.
604
+ # Sizes are [batch_size, 1, 1, to_seq_length]
605
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
606
+ # this attention mask is more simple than the triangular masking of causal attention
607
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
608
+ attention_mask = attention_mask[:, None, None, :]
609
+
610
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
611
+ # masked positions, this operation will create a tensor which is 0.0 for
612
+ # positions we want to attend and the dtype's smallest value for masked positions.
613
+ # Since we are adding it to the raw scores before the softmax, this is
614
+ # effectively the same as removing these entirely.
615
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
616
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
617
+
618
+ # If a 2D or 3D attention mask is provided for the cross-attention
619
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
620
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
621
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
622
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
623
+ if encoder_attention_mask is None:
624
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
625
+ if _use_sdpa:
626
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
627
+ mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
628
+ )
629
+ elif not self._attn_implementation == "flash_attention_2":
630
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
631
+ else:
632
+ encoder_attention_mask = None
633
+
634
+ # Prepare head mask if needed
635
+ # 1.0 in head_mask indicate we keep the head
636
+ # attention_probs has shape bsz x n_heads x N x N
637
+ # head_mask has shape n_layer x batch x n_heads x N x N
638
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
639
+
640
+ if token_type_ids is not None:
641
+ token_type_embeds = self.wte(token_type_ids)
642
+ hidden_states = hidden_states + token_type_embeds
643
+
644
+ hidden_states = self.drop(hidden_states)
645
+
646
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
647
+
648
+ if self.gradient_checkpointing and self.training:
649
+ if use_cache:
650
+ logger.warning_once(
651
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
652
+ )
653
+ use_cache = False
654
+
655
+ presents = () if use_cache else None
656
+ all_self_attentions = () if output_attentions else None
657
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
658
+ all_hidden_states = () if output_hidden_states else None
659
+ for i in range(len(self.h)):
660
+ block, layer_past = self.h[i], past_key_values[i]
661
+ # Model parallel
662
+ if self.model_parallel:
663
+ torch.cuda.set_device(hidden_states.device)
664
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
665
+ if layer_past is not None:
666
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
667
+ # Ensure that attention_mask is always on the same device as hidden_states
668
+ if attention_mask is not None:
669
+ attention_mask = attention_mask.to(hidden_states.device)
670
+ if isinstance(head_mask, torch.Tensor):
671
+ head_mask = head_mask.to(hidden_states.device)
672
+ if output_hidden_states:
673
+ all_hidden_states = all_hidden_states + (hidden_states,)
674
+
675
+ if self.gradient_checkpointing and self.training:
676
+ outputs = self._gradient_checkpointing_func(
677
+ block.__call__,
678
+ hidden_states,
679
+ None,
680
+ attention_mask,
681
+ head_mask[i],
682
+ encoder_hidden_states,
683
+ encoder_attention_mask,
684
+ use_cache,
685
+ output_attentions,
686
+ )
687
+ else:
688
+ outputs = block(
689
+ hidden_states,
690
+ layer_past=layer_past,
691
+ attention_mask=attention_mask,
692
+ head_mask=head_mask[i],
693
+ encoder_hidden_states=encoder_hidden_states,
694
+ encoder_attention_mask=encoder_attention_mask,
695
+ use_cache=use_cache,
696
+ output_attentions=output_attentions,
697
+ )
698
+
699
+ hidden_states = outputs[0]
700
+ if use_cache is True:
701
+ presents = presents + (outputs[1],)
702
+
703
+ if output_attentions:
704
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
705
+ if self.config.add_cross_attention:
706
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
707
+
708
+ # Model Parallel: If it's the last layer for that device, put things on the next device
709
+ if self.model_parallel:
710
+ for k, v in self.device_map.items():
711
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
712
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
713
+
714
+ hidden_states = self.ln_f(hidden_states)
715
+
716
+ hidden_states = hidden_states.view(output_shape)
717
+ # Add last hidden state
718
+ if output_hidden_states:
719
+ all_hidden_states = all_hidden_states + (hidden_states,)
720
+
721
+ if not return_dict:
722
+ return tuple(
723
+ v
724
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
725
+ if v is not None
726
+ )
727
+
728
+ return BaseModelOutputWithPastAndCrossAttentions(
729
+ last_hidden_state=hidden_states,
730
+ past_key_values=presents,
731
+ hidden_states=all_hidden_states,
732
+ attentions=all_self_attentions,
733
+ cross_attentions=all_cross_attentions,
734
+ )
735
+
736
+
737
+ class DuoPredictGPT2LMHeadModel(DuoPredictGPT2PretrainedModel, GenerationMixin):
738
+ _tied_weights_keys = ["lm_head.weight"]
739
+
740
+ def __init__(self, config):
741
+ super().__init__(config)
742
+ self.transformer = DuoPredictGPT2Model(config)
743
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
744
+
745
+ # Model parallel
746
+ self.model_parallel = False
747
+ self.device_map = None
748
+
749
+ # Initialize weights and apply final processing
750
+ self.post_init()
751
+
752
+
753
+ def parallelize(self, device_map=None):
754
+ warnings.warn(
755
+ "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
756
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
757
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
758
+ " 0, 'transformer.h.1': 1, ...}",
759
+ FutureWarning,
760
+ )
761
+ self.device_map = (
762
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
763
+ if device_map is None
764
+ else device_map
765
+ )
766
+ assert_device_map(self.device_map, len(self.transformer.h))
767
+ self.transformer.parallelize(self.device_map)
768
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
769
+ self.model_parallel = True
770
+
771
+ def deparallelize(self):
772
+ warnings.warn(
773
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
774
+ FutureWarning,
775
+ )
776
+ self.transformer.deparallelize()
777
+ self.transformer = self.transformer.to("cpu")
778
+ self.lm_head = self.lm_head.to("cpu")
779
+ self.model_parallel = False
780
+ torch.cuda.empty_cache()
781
+
782
+ def get_output_embeddings(self):
783
+ return self.lm_head
784
+
785
+ def set_output_embeddings(self, new_embeddings):
786
+ self.lm_head = new_embeddings
787
+
788
+ def forward(
789
+ self,
790
+ input_ids: Optional[torch.LongTensor] = None,
791
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
792
+ attention_mask: Optional[torch.FloatTensor] = None,
793
+ token_type_ids: Optional[torch.LongTensor] = None,
794
+ position_ids: Optional[torch.LongTensor] = None,
795
+ head_mask: Optional[torch.FloatTensor] = None,
796
+ inputs_embeds: Optional[torch.FloatTensor] = None,
797
+ encoder_hidden_states: Optional[torch.Tensor] = None,
798
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
799
+ labels: Optional[torch.LongTensor] = None,
800
+ use_cache: Optional[bool] = None,
801
+ output_attentions: Optional[bool] = None,
802
+ output_hidden_states: Optional[bool] = None,
803
+ return_dict: Optional[bool] = None,
804
+ **kwargs,
805
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
806
+ r"""
807
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
808
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
809
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
810
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
811
+ """
812
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
813
+
814
+ transformer_outputs = self.transformer(
815
+ input_ids,
816
+ past_key_values=past_key_values,
817
+ attention_mask=attention_mask,
818
+ token_type_ids=token_type_ids,
819
+ position_ids=position_ids,
820
+ head_mask=head_mask,
821
+ inputs_embeds=inputs_embeds,
822
+ encoder_hidden_states=encoder_hidden_states,
823
+ encoder_attention_mask=encoder_attention_mask,
824
+ use_cache=use_cache,
825
+ output_attentions=output_attentions,
826
+ output_hidden_states=output_hidden_states,
827
+ return_dict=return_dict,
828
+ )
829
+ hidden_states = transformer_outputs[0]
830
+
831
+ # Set device for model parallelism
832
+ if self.model_parallel:
833
+ torch.cuda.set_device(self.transformer.first_device)
834
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
835
+
836
+ lm_logits = self.lm_head(hidden_states)
837
+
838
+ loss = None
839
+ if labels is not None:
840
+ # Flatten the tokens
841
+ total_labels = torch.full((lm_logits.shape[:2]), -100, dtype=input_ids.dtype, device=input_ids.device)
842
+ total_labels[:, :-1:2] = labels[:, 1: ]
843
+ total_labels[:, 1:-1:2] = labels[:, :-1]
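+ # Interleave the labels to match the duo layout: even slots of total_labels are
+ # filled with labels[:, 1:], the interleaved odd slots with labels[:, :-1]
+ # (i.e. each pair is supervised on the next and the current token respectively),
+ # and every other slot stays at -100 so it is ignored by the loss.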
844
+ loss = self.loss_function(
845
+ lm_logits,
846
+ total_labels,
847
+ vocab_size=self.config.vocab_size,
848
+ **kwargs,
849
+ )
850
+
851
+ if not return_dict:
852
+ output = (lm_logits,) + transformer_outputs[1:]
853
+ return ((loss,) + output) if loss is not None else output
854
+
855
+ return CausalLMOutputWithCrossAttentions(
856
+ loss=loss,
857
+ logits=lm_logits,
858
+ past_key_values=transformer_outputs.past_key_values,
859
+ hidden_states=transformer_outputs.hidden_states,
860
+ attentions=transformer_outputs.attentions,
861
+ cross_attentions=transformer_outputs.cross_attentions,
862
+ )
863
+
864
+ @staticmethod
865
+ def _reorder_cache(
866
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
867
+ ) -> Tuple[Tuple[torch.Tensor]]:
868
+ """
869
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
870
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
871
+ beam_idx at every generation step.
872
+ """
873
+ return tuple(
874
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
875
+ for layer_past in past_key_values
876
+ )
877
+
878
+
879
+
880
+ from transformers import AutoConfig, AutoModel
881
+ AutoConfig.register("duo-predict-gpt2", DuoPredictGPT2Config)
882
+ AutoModel.register(DuoPredictGPT2Config, DuoPredictGPT2LMHeadModel)
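+ # With these registrations, the Auto classes can resolve the custom model type,
+ # e.g. (a sketch, assuming this module and its src.models dependency are importable):
+ #   config = AutoConfig.for_model("duo-predict-gpt2")
+ #   model = AutoModel.from_config(config)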
883
+
884
+
885
+ __all__ = [
886
+ "DuoPredictGPT2LMHeadModel",
887
+ "DuoPredictGPT2Model",
888
+ "DuoPredictGPT2Config",
889
+ "DuoPredictGPT2Attention",
890
+ "DuoPredictGPT2MLP",
891
+ "DuoPredictGPT2Block",
892
+ ]
893
+
894
+
895
+ if __name__ == "__main__":
896
+ cg = DuoPredictGPT2Config()
897
+ model = DuoPredictGPT2LMHeadModel(cg)
898
+ from src.utils.model_utlis import print_trainable_parameters
899
+ print_trainable_parameters(model)
900
+ model(torch.randint(0, 10000, (1, 100)))
901
+ print()