Winmodel committed
Commit ed5db8f · verified · 1 Parent(s): 44bf29c

Upload GPT

Files changed (5):
  1. README.md +199 -0
  2. config.json +17 -0
  3. hf_configuration.py +21 -0
  4. hf_modeling.py +291 -0
  5. pytorch_model.bin +3 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "architectures": [
+     "GPT"
+   ],
+   "auto_map": {
+     "AutoConfig": "hf_configuration.ExGPTConfig",
+     "AutoModel": "hf_modeling.GPT"
+   },
+   "block_size": 1024,
+   "model_type": "ExGPT",
+   "n_embd": 768,
+   "n_head": 12,
+   "n_layer": 12,
+   "torch_dtype": "float32",
+   "transformers_version": "4.48.3",
+   "vocab_size": 50304
+ }
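
The `auto_map` entries above are what let the stock Auto classes resolve the custom code shipped with this repository. Below is a minimal loading sketch, assuming `trust_remote_code=True` is acceptable; the repository id is a placeholder, not something recorded in this commit.

```python
from transformers import AutoConfig, AutoModel

repo_id = "Winmodel/<repo-name>"  # hypothetical placeholder -- substitute the real repo id

# auto_map routes these to hf_configuration.ExGPTConfig and hf_modeling.GPT
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
print(type(config).__name__, type(model).__name__)  # ExGPTConfig, GPT
```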
hf_configuration.py ADDED
@@ -0,0 +1,21 @@
+ from transformers import PretrainedConfig
+
+ class ExGPTConfig(PretrainedConfig):
+     model_type = "ExGPT"
+
+     def __init__(
+         self,
+         block_size: int = 1024,   # context (sequence) length
+         vocab_size: int = 50257,  # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
+         n_layer: int = 12,
+         n_head: int = 12,
+         n_embd: int = 768,
+         **kwargs
+     ):
+         self.block_size = block_size
+         self.vocab_size = vocab_size
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.n_embd = n_embd
+         super().__init__(**kwargs)
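
For purely local use (without the Hub's `auto_map` resolution), the same classes can be registered with the Auto classes directly. A sketch, assuming `hf_configuration.py` and `hf_modeling.py` are importable from the working directory:

```python
from transformers import AutoConfig, AutoModel
from hf_configuration import ExGPTConfig
from hf_modeling import GPT

# register model_type "ExGPT" so AutoConfig/AutoModel can resolve it locally
AutoConfig.register("ExGPT", ExGPTConfig)
AutoModel.register(ExGPTConfig, GPT)

config = ExGPTConfig(vocab_size=50304)  # match the padded vocab recorded in config.json
model = AutoModel.from_config(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```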
hf_modeling.py ADDED
@@ -0,0 +1,291 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import math
+ import inspect
+ import os
+ from hf_configuration import ExGPTConfig
+ from transformers import PreTrainedModel
+
+ # ==================================================
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         # key, query, value projections for all heads, computed in one matmul
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1  # flag picked up by _init_weights
+         # regularization
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         # not really a 'bias', more of a causal mask
+         self.register_buffer('bias', torch.tril(torch.ones(config.block_size, config.block_size))
+                              .view(1, 1, config.block_size, config.block_size))  # (1, 1, T, T): batch dim, head dim, then the T x T mask
+
+     def forward(self, x):
+         B, T, C = x.size()  # batch, sequence length, embedding dim
+         qkv = self.c_attn(x)  # project once, then split and reshape per head
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+
+         # the naive implementation would materialize the full quadratic attention matrix:
+         # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+         # att = F.softmax(att, dim=-1)
+         # y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+         y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+         # output projection
+         out = self.c_proj(y)
+         return out
+
+ class MLP(nn.Module):
+     """Gated MLP (SwiGLU-style). Note: the SiLU is applied to the product of the two
+     projections here, not to the gate branch alone as in the standard SwiGLU formulation."""
+     def __init__(self, config):
+         super().__init__()
+         self.gate = nn.Linear(config.n_embd, 4 * config.n_embd)
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+         self.silu = nn.SiLU()
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1  # flag picked up by _init_weights
+
+     def forward(self, x):
+         # original GPT-2 MLP: x = self.c_proj(self.gelu(self.c_fc(x)))
+         x = self.c_proj(self.silu(self.c_fc(x) * self.gate(x)))
+         return x
+
+ class Block(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.RMSNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.RMSNorm(config.n_embd)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+ class GPT(PreTrainedModel):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),  # learned positional embedding
+             h = nn.ModuleList(Block(config) for _ in range(config.n_layer)),
+             ln_f = nn.RMSNorm(config.n_embd),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # weight sharing scheme: tie the input embedding and the output head, GPT-2 style
+         self.transformer.wte.weight = self.lm_head.weight
+         # (slightly worse training loss in my observation)
+
+         # init params
+         # apply() runs the fn recursively on every submodule (as returned by .children()) as well as self
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):  # called once per submodule by self.apply
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'NANOGPT_SCALE_INIT'):  # residual projections get a scaled-down init
+                 std *= (2 * self.config.n_layer) ** -0.5
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)  # typically std would be 1/sqrt(fan_in)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, target=None):
+         # idx is of shape (B, T)
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Cannot forward a sequence of length {T}, block size is only {self.config.block_size}"
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # shape (T,)
+         tok_emb = self.transformer.wte(idx)
+         pos_emb = self.transformer.wpe(pos)
+         x = tok_emb + pos_emb
+         # forward the blocks of the transformer
+         for block in self.transformer.h:
+             x = block(x)
+         # forward the final norm and the classifier
+         x = self.transformer.ln_f(x)
+         loss = None
+         logits = self.lm_head(x)  # (B, T, vocab_size)
+         if target is not None:
+             # flatten (B, T) targets to (B*T,) and logits to (B*T, vocab_size) for cross entropy
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1))
+         return logits, loss
+
+     @classmethod
+     def from_pretrained(cls, model_type):
+         """Loads pretrained GPT-2 model weights from huggingface"""
+         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+         from transformers import GPT2LMHeadModel
+         print("loading weights from pretrained gpt: %s" % model_type)
+
+         # n_layer, n_head and n_embd are determined from model_type
+         config_args = {
+             'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),   # 124M params
+             'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
+             'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280),  # 774M params
+             'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M params
+         }[model_type]
+         config_args['vocab_size'] = 50257  # always 50257 for GPT model checkpoints
+         config_args['block_size'] = 1024   # always 1024 for GPT model checkpoints
+         # create a from-scratch initialized model
+         config = ExGPTConfig(**config_args)
+         model = GPT(config)
+         sd = model.state_dict()
+         sd_keys = sd.keys()
+         sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]  # discard this mask / buffer, not a param
+
+         # init a huggingface/transformers model
+         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+         sd_hf = model_hf.state_dict()
+
+         # copy while ensuring all of the parameters are aligned and match in names and shapes
+         sd_keys_hf = sd_hf.keys()
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]  # ignore these, just a buffer
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]  # same, just the mask (buffer)
+         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+         # the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear,
+         # which means we have to transpose these weights when we import them
+         assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+         for k in sd_keys_hf:
+             if any(k.endswith(w) for w in transposed):
+                 # special treatment for the Conv1D weights we need to transpose
+                 assert sd_hf[k].shape[::-1] == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k].t())
+             else:
+                 # vanilla copy over the other parameters
+                 assert sd_hf[k].shape == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k])
+
+         return model
+
+     def configure_optimizers(self, weight_decay, learning_rate, device):
+         # start with all of the candidate parameters (that require grad)
+         param_dict = {pn: p for pn, p in self.named_parameters()}
+         param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+         # create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
+         # i.e. all weight tensors in matmuls + embeddings decay, all biases and norms don't.
+         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+         nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+         optim_groups = [
+             {'params': decay_params, 'weight_decay': weight_decay},
+             {'params': nodecay_params, 'weight_decay': 0.0}
+         ]
+         num_decay_params = sum(p.numel() for p in decay_params)
+         num_nodecay_params = sum(p.numel() for p in nodecay_params)
+         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+         # create the AdamW optimizer and use the fused version if it is available
+         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+         use_fused = fused_available and 'cuda' in device
+         print(f"using fused AdamW: {use_fused}")
+         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
+         return optimizer
+
+ # ===============================================================================================
+ # sampling defaults (not used elsewhere in this module)
+ num_return_sequences = 5
+ max_length = 30
+
+ # =================================================================================================
+ import numpy as np
+
+ def load_tokens(filename):
+     npt = np.load(filename)
+     ptt = torch.tensor(npt, dtype=torch.long)
+     return ptt
+
+ class DataLoaderLite:
+     def __init__(self, B, T, process_rank, num_processes, split):
+         self.B = B
+         self.T = T
+         self.process_rank = process_rank
+         self.num_processes = num_processes
+         assert split in {'train', 'val'}
+
+         # collect the shard filenames
+         data_root = "edu_fineweb10B"
+         shards = os.listdir(data_root)
+         shards = [s for s in shards if split in s]
+         shards = sorted(shards)
+         shards = [os.path.join(data_root, s) for s in shards]
+         self.shards = shards
+         assert len(shards) > 0, f"no shards found for split {split}"
+         if process_rank == 0:
+             print(f"found {len(shards)} shards for split {split}")
+
+         # state: each process starts at a different offset so the processes stride through the data
+         self.reset()
+
+     def reset(self):
+         # state, init at shard zero
+         self.current_shard = 0
+         self.tokens = load_tokens(self.shards[self.current_shard])
+         self.current_position = self.B * self.T * self.process_rank
+
+     def next_batch(self):
+         B, T = self.B, self.T
+         buf = self.tokens[self.current_position:self.current_position + B*T + 1]
+         x = (buf[:-1]).view(B, T)  # inputs
+         y = (buf[1:]).view(B, T)   # targets
+
+         # advance the position in the tensor; each process strides by num_processes blocks
+         self.current_position += B * T * self.num_processes
+         if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):  # when we run out of tokens in a shard, advance to the next one
+             self.current_shard = (self.current_shard + 1) % len(self.shards)
+             self.tokens = load_tokens(self.shards[self.current_shard])
+             self.current_position = B * T * self.process_rank
+         return x, y
+
+ # -----------------------------------------------------------------------------
+ # helper function for HellaSwag eval
+ # takes tokens, mask, and logits, returns the index of the completion with the lowest loss
+
+ def get_most_likely_row(tokens, mask, logits):
+     # evaluate the autoregressive loss at all positions
+     shift_logits = (logits[..., :-1, :]).contiguous()
+     shift_tokens = (tokens[..., 1:]).contiguous()
+     flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+     flat_shift_tokens = shift_tokens.view(-1)
+     shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
+     shift_losses = shift_losses.view(tokens.size(0), -1)
+     # now get the average loss just for the completion region (where mask == 1), in each row
+     shift_mask = (mask[..., 1:]).contiguous()  # we must shift the mask, so we start at the last prompt token
+     masked_shift_losses = shift_losses * shift_mask
+     # sum and divide by the number of 1s in the mask
+     sum_loss = masked_shift_losses.sum(dim=1)
+     avg_loss = sum_loss / shift_mask.sum(dim=1)
+     # now we have a loss for each of the 4 completions
+     # the one with the lowest loss should be the most likely
+     pred_norm = avg_loss.argmin().item()
+     return pred_norm
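
Because `GPT.forward` returns a plain `(logits, loss)` tuple rather than a `ModelOutput`, text generation has to be written as an explicit sampling loop. A rough sketch under that assumption, using the GPT-2 `tiktoken` encoding (the checkpoint's own tokenizer is not part of this commit):

```python
import torch
import tiktoken
from hf_configuration import ExGPTConfig
from hf_modeling import GPT

enc = tiktoken.get_encoding("gpt2")
model = GPT(ExGPTConfig(vocab_size=50304))  # randomly initialized here; load the Hub checkpoint in practice
model.eval()

idx = torch.tensor([enc.encode("Hello, I'm a language model,")], dtype=torch.long)  # (1, T)
with torch.no_grad():
    for _ in range(30):
        idx_cond = idx[:, -model.config.block_size:]          # crop to the context window
        logits, _ = model(idx_cond)                           # forward returns (logits, loss)
        probs = torch.softmax(logits[:, -1, :50257], dim=-1)  # restrict to the 50257 real GPT-2 ids; the rest of the 50304 vocab is padding
        topk_probs, topk_idx = torch.topk(probs, 50, dim=-1)  # top-k sampling, k = 50
        next_tok = topk_idx.gather(-1, torch.multinomial(topk_probs, 1))
        idx = torch.cat([idx, next_tok], dim=1)
print(enc.decode(idx[0].tolist()))
```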
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df3ca80a6243fa9c565c117bfba39515c188972a61e7754f5fa9ea7d32c75f70
+ size 661604817
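
The weights file itself is stored in Git LFS; the three lines above are only the pointer (spec version, SHA-256 of the blob, and its size in bytes). A small sketch of fetching the real file and checking it against the pointer, with the repository id again left as a placeholder:

```python
import hashlib
import os
from huggingface_hub import hf_hub_download

path = hf_hub_download("Winmodel/<repo-name>", "pytorch_model.bin")  # hypothetical placeholder repo id

with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

assert os.path.getsize(path) == 661604817
assert digest == "df3ca80a6243fa9c565c117bfba39515c188972a61e7754f5fa9ea7d32c75f70"
print("checkpoint matches the LFS pointer")
```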