nroggendorff committed
Commit 9982b19 · verified · 1 Parent(s): 6fd0103

im never using chatgpt again

Files changed (1)
  1. train.py +247 -120
train.py CHANGED
@@ -1,117 +1,107 @@
  import trl
  from transformers import (
-     AutoTokenizer, LlamaConfig, LlamaForCausalLM,
-     PreTrainedTokenizerFast
  )
- from trl import SFTConfig, SFTTrainer
  from datasets import load_dataset, Dataset
  from tokenizers import ByteLevelBPETokenizer
  from huggingface_hub import HfApi
  from itertools import islice

- from logging import getLogger, StreamHandler, INFO

- logger = getLogger(__name__)
- logger.setLevel(INFO)
- handler = StreamHandler()
- logger.addHandler(handler)

- class Config:
-     def __init__(self):
-         # Model and training hyperparameters
-         self.BATCH_SIZE = 16
-         self.EPOCHS = 3
-         self.LEARNING_RATE = 2e-4
-         self.MAX_SEQ_LENGTH = 512
-         self.VOCAB_SIZE = 32000
-         self.FP16 = True
-         self.WEIGHT_DECAY = 1e-3
-         self.GRADIENT_ACCUMULATION_STEPS = self.BATCH_SIZE // 4
-
-         # Dataset configurations
-         self.INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
-         self.INSTRUCT_DATASET = "nroggendorff/elephant"
-         self.SHARD_SIZE = int(2e+5)
-
-         # Output and repo settings
-         self.OUTPUT_REPO = "nroggendorff/smallama"
-         self.PUSH_TO_HUB = True
-         self.INSTRUCT_FINETUNE_BOOL = False
-
-         # Training steps and warmup
-         self.FACTOR = 12 ** 3 // 3
-         self.TOTAL_STEPS = (self.SHARD_SIZE * self.EPOCHS) // (self.BATCH_SIZE * self.GRADIENT_ACCUMULATION_STEPS)
-         self.WARMUP_STEPS = int(self.TOTAL_STEPS * 0.1)
-
-         # Initial state for shard offset
-         self.INIT = 0
-
-         # ignore
-         self.getConfig = lambda: self._args()
-
-     # @staticmethod
-     def _args(self):
-         return SFTConfig(
-             output_dir="model",
-             num_train_epochs=self.EPOCHS,
-             per_device_train_batch_size=self.BATCH_SIZE,
-             learning_rate=self.LEARNING_RATE,
-             warmup_steps=self.WARMUP_STEPS,
-             weight_decay=self.WEIGHT_DECAY,
-             gradient_accumulation_steps=self.GRADIENT_ACCUMULATION_STEPS,
-             fp16=self.FP16,
-             save_steps=int(self.WARMUP_STEPS * 5),
-             logging_steps=int(self.WARMUP_STEPS),
-             save_total_limit=2,
-             report_to="none",
-         )
-
- config = Config().getConfig()

  class Space:
      def __init__(self):
          self.api = HfApi()
          self.pause = lambda: self.api.pause_space("nroggendorff/train-llama")

- space = Space()
-
  class FineError(Exception):
-     def __init__(self, message="Training completed successfully."):
          self.message = message
          super().__init__(self.message)

- def load_data(dataset_name: str, split: str, shard_size: int, init_offset: int = 0) -> Dataset:
-     dataset = load_dataset(dataset_name, split=split, streaming=True)
-     shard_start = init_offset * shard_size
-     data_list = list(islice(dataset, shard_start, shard_start + shard_size))
-     return Dataset.from_dict({'text': [example.get('text', '') for example in data_list]})
-
- def encode_decode(texts, tokenizer):
-     if tokenizer.pad_token is None:
-         tokenizer.pad_token = tokenizer.eos_token
-     tokenized_texts = tokenizer(
-         texts, padding="max_length", truncation=True, max_length=config.MAX_SEQ_LENGTH, return_tensors="pt"
      ).input_ids
-     return tokenizer.batch_decode(tokenized_texts) if tokenized_texts.dim() >= 1 else [tokenizer.pad_token * config.MAX_SEQ_LENGTH]

  def create_tokenizer(training_corpus):
      tokenizer = ByteLevelBPETokenizer()
      special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
-     tokenizer.train_from_iterator(training_corpus, vocab_size=config.VOCAB_SIZE, min_frequency=2, special_tokens=special_tokens)
-     return PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)

- def load_tokenizer(repo: str):
-     return AutoTokenizer.from_pretrained(repo)

  def get_training_corpus(dataset):
      for i in range(0, len(dataset['text']), 1000):
          yield dataset['text'][i : i + 1000]

- def format_prompts(examples, tokenizer, is_instructional):
      texts = []
      for text in examples['text']:
          if text and len(text.strip()) > 0:
-             if is_instructional:
                  conversation = []
                  parts = text.split('<|end|>')
                  for i in range(0, len(parts) - 1, 2):
@@ -119,22 +109,29 @@ def format_prompts(examples, tokenizer, is_instructional):
                      response = parts[i + 1].replace("<|bot|>", "").strip()
                      conversation.append({"role": "user", "content": prompt})
                      conversation.append({"role": "assistant", "content": response})
-                 coded_text = tokenizer.code(tokenizer.apply_chat_template(conversation, tokenize=False))
                  texts.append(coded_text)
              else:
                  texts.append(tokenizer.bos_token + tokenizer.code(text) + tokenizer.eos_token)
-     if not texts:
          raise ValueError("No valid texts found in examples for formatting.")
-     return {'text': tokenizer.code(texts)}

  def create_model(tokenizer):
-     model_config = LlamaConfig(
          vocab_size=tokenizer.vocab_size,
-         hidden_size=config.FACTOR,
-         intermediate_size=config.FACTOR * 4,
          num_hidden_layers=12,
          num_attention_heads=12,
-         max_position_embeddings=config.MAX_SEQ_LENGTH,
          rms_norm_eps=1e-5,
          initializer_range=0.02,
          use_cache=True,
@@ -143,47 +140,177 @@ def create_model(tokenizer):
          eos_token_id=tokenizer.eos_token_id,
          tie_word_embeddings=False,
      )
-     return LlamaForCausalLM(model_config)

- def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
-     dataset = dataset.map(
-         lambda examples: format_prompts(examples, tokenizer, is_instructional),
-         batched=True,
-         remove_columns=dataset.column_names
      )
-     trainer = SFTTrainer(
          model=model,
          tokenizer=tokenizer,
-         config=config,
-         train_dataset=dataset
      )
-     train_result = trainer.train()

-     if push_to_hub:
-         repo_id = config.OUTPUT_REPO + "-it" if config.INSTRUCT_FINETUNE_BOOL else config.OUTPUT_REPO
-         trainer.model.push_to_hub(repo_id, commit_message=f"Training loss: {train_result.training_loss:.4f}", force=True)
-         trainer.tokenizer.push_to_hub(repo_id, commit_message=f"Training loss: {train_result.training_loss:.4f}", force=True)
      else:
-         trainer.model.save_pretrained("model")
-         trainer.tokenizer.save_pretrained("tokenizer")
-
- def main():
-     dataset = load_data(config.INPUT_DATASET, "train", config.SHARD_SIZE, config.INIT)
-     tokenizer = (
-         load_tokenizer(config.OUTPUT_REPO)
-         if config.INSTRUCT_FINETUNE_BOOL and config.INIT > 0
-         else create_tokenizer(get_training_corpus(dataset))
-     )
-     model = (
-         load_model()
-         if config.INSTRUCT_FINETUNE_BOOL or config.INIT > 0
-         else create_model(tokenizer)
-     )
-     train_model(model, tokenizer, dataset, config.PUSH_TO_HUB, config.INSTRUCT_FINETUNE_BOOL)

  if __name__ == "__main__":
      try:
-         main()
      except Exception as e:
-         logger.error(f"{type(e).__name__}: {e}")
-         space.pause()
+ import os
+ from sys import exit
+ import torch
  import trl
  from transformers import (
+     AutoTokenizer, LlamaConfig, AutoModelForCausalLM, LlamaForCausalLM,
+     TrainingArguments, PreTrainedTokenizerFast, AdamW, get_cosine_schedule_with_warmup
  )
  from datasets import load_dataset, Dataset
  from tokenizers import ByteLevelBPETokenizer
  from huggingface_hub import HfApi
+ from trl import SFTConfig, SFTTrainer
+ from torch.utils.data import DataLoader
  from itertools import islice

+ BATCH_SIZE = 16
+ EPOCHS = 3
+ LEARNING_RATE = 2e-4
+ FACTOR = 12 ** 3 // 3
+ MAX_SEQ_LENGTH = 512
+ VOCAB_SIZE = 32000
+ INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
+ INSTRUCT_DATASET = "nroggendorff/elephant"
+ OUTPUT_REPO = "nroggendorff/smallama"
+ INSTRUCT_FINETUNE_BOOL = False
+ INIT = 0
+ SHARD_SIZE = int(2e+5)
+ FP16 = True
+ WEIGHT_DECAY = 1e-3
+ GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // 4

+ PUSH_TO_HUB = True

+ total_steps = (SHARD_SIZE * EPOCHS) // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)
+ WARMUP_STEPS = total_steps * 0.1

  class Space:
      def __init__(self):
          self.api = HfApi()
          self.pause = lambda: self.api.pause_space("nroggendorff/train-llama")

  class FineError(Exception):
+     def __init__(self, message="Script execution has completed."):
          self.message = message
          super().__init__(self.message)

+ def load_data():
+     if not INSTRUCT_FINETUNE_BOOL:
+         dataset = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
+     else:
+         dataset = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)
+
+     start = INIT * SHARD_SIZE
+     data_list = list(islice(dataset, start, start + SHARD_SIZE))
+
+     dataset = Dataset.from_dict({'text': [example['text'] for example in data_list]})
+     return dataset
+
+ def encode_decode(texts, tok):
+     if tok.pad_token is None:
+         tok.pad_token = tok.eos_token
+
+     tokenized_texts = tok(
+         texts,
+         padding="max_length",
+         truncation=True,
+         max_length=MAX_SEQ_LENGTH,
+         return_tensors="pt"
      ).input_ids
+
+     if tokenized_texts.dim() >= 1:
+         decoded_texts = tok.batch_decode(tokenized_texts)
+     else:
+         print('Found invalid entry in examples. Returning dummy..')
+         decoded_texts = [tokenizer.pad_token * MAX_SEQ_LENGTH]
+
+     islist = not len(decoded_texts) == 1
+
+     return decoded_texts if islist else decoded_texts[0]

  def create_tokenizer(training_corpus):
      tokenizer = ByteLevelBPETokenizer()
      special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
+     tokenizer.train_from_iterator(
+         training_corpus,
+         vocab_size=VOCAB_SIZE,
+         min_frequency=2,
+         special_tokens=special_tokens
+     )
+     fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
+     return fast_tokenizer

+ def load_tokenizer():
+     return AutoTokenizer.from_pretrained(OUTPUT_REPO + '-it' if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO)

  def get_training_corpus(dataset):
      for i in range(0, len(dataset['text']), 1000):
          yield dataset['text'][i : i + 1000]

+ def format_prompts(examples, tokenizer, isinst):
      texts = []
      for text in examples['text']:
          if text and len(text.strip()) > 0:
+             if isinst:
                  conversation = []
                  parts = text.split('<|end|>')
                  for i in range(0, len(parts) - 1, 2):
                      response = parts[i + 1].replace("<|bot|>", "").strip()
                      conversation.append({"role": "user", "content": prompt})
                      conversation.append({"role": "assistant", "content": response})
+                 formatted_conversation = tokenizer.apply_chat_template(conversation, tokenize=False)
+                 coded_text = tokenizer.code(formatted_conversation)
                  texts.append(coded_text)
              else:
                  texts.append(tokenizer.bos_token + tokenizer.code(text) + tokenizer.eos_token)
+         else:
+             print('Found empty entry in examples. Moving on..')
+             continue
+
+     if len(texts) == 0:
          raise ValueError("No valid texts found in examples for formatting.")
+
+     coded_texts = tokenizer.code(texts)
+     return {'text': coded_texts}

  def create_model(tokenizer):
+     config = LlamaConfig(
          vocab_size=tokenizer.vocab_size,
+         hidden_size=FACTOR,
+         intermediate_size=FACTOR * 4,
          num_hidden_layers=12,
          num_attention_heads=12,
+         max_position_embeddings=MAX_SEQ_LENGTH,
          rms_norm_eps=1e-5,
          initializer_range=0.02,
          use_cache=True,
          eos_token_id=tokenizer.eos_token_id,
          tie_word_embeddings=False,
      )
+     return LlamaForCausalLM(config)
+
+ def load_model():
+     return AutoModelForCausalLM.from_pretrained(OUTPUT_REPO + '-it' if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO)
+
+ def configure_tokenizer(tokenizer):
+     special_tokens = {
+         "bos_token": "<s>",
+         "eos_token": "</s>",
+         "unk_token": "<unk>",
+         "pad_token": "<pad>",
+         "mask_token": "<mask>",
+         "additional_special_tokens": []
+     }
+     if INSTRUCT_FINETUNE_BOOL:
+         special_tokens["additional_special_tokens"] = ["<|user|>", "<|bot|>", "<|end|>"]
+     tokenizer.add_special_tokens(special_tokens)
+
+     if INSTRUCT_FINETUNE_BOOL:
+         tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
+         tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")
+
+     chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+     tokenizer.chat_template = chat_template
+
+     tokenizer.code = lambda example: encode_decode(example, tokenizer)
+
+ def update_tokenizer(tokenizer, dataset, batch_size=1000):
+     existing_vocab = tokenizer.get_vocab()
+     oov_tokens = set()
+
+     for i in range(0, len(dataset['text']), batch_size):
+         batch = dataset['text'][i:i + batch_size]
+
+         for text in batch:
+             token_ids = tokenizer.encode(text, add_special_tokens=False)
+
+             for token_id in token_ids:
+                 token = tokenizer.decode([token_id])
+                 if token.strip() and token not in existing_vocab:
+                     oov_tokens.add(token)
+
+     if oov_tokens:
+         num_added = tokenizer.add_tokens(list(oov_tokens))
+         return num_added
+
+     return 0
+
+ def train_model(model, tokenizer, dataset, push, isinst):
+     args = SFTConfig(
+         output_dir="model",
+         num_train_epochs=EPOCHS,
+         per_device_train_batch_size=BATCH_SIZE,
+         learning_rate=LEARNING_RATE,
+         optim="adamw_torch",
+         warmup_steps=WARMUP_STEPS,
+         weight_decay=WEIGHT_DECAY,
+         gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
+         fp16=FP16,
+         save_steps=WARMUP_STEPS * 5,
+         logging_steps=WARMUP_STEPS,
+         eval_strategy="no",
+         report_to="no",
+         # eval_steps=WARMUP_STEPS,
+         save_total_limit=2,
+     )

+     optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=WEIGHT_DECAY)
+     scheduler = get_cosine_schedule_with_warmup(
+         optimizer,
+         num_warmup_steps=args.warmup_steps,
+         num_training_steps=total_steps
      )
+
+     dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer, isinst), batched=True, remove_columns=dataset.column_names)
+
+     if 'text' not in dataset.column_names:
+         raise ValueError("Dataset transformation failed: 'text' column missing after mapping.")
+
+     print("Mapped dataset sample length:", len(dataset[0]['text']))
+
+     try:
+         test_input = tokenizer(
+             ["This is a test input."],
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=MAX_SEQ_LENGTH
+         )
+         test_output = model(**test_input)
+         print("Model test output shape:", test_output.logits.shape)
+     except RuntimeError as e:
+         print(f"Error processing test batch: {e}")
+
+     trainer = trl.SFTTrainer(
          model=model,
          tokenizer=tokenizer,
+         args=args,
+         train_dataset=dataset,
+         # dataset_text_field='text',
+         max_seq_length=MAX_SEQ_LENGTH,
+         optimizers=(optimizer, scheduler)
      )
+
+     train = trainer.train()
+
+     trained_model = trainer.model
+     trained_tokenizer = trainer.tokenizer
+
+     if push:
+         repo_id = OUTPUT_REPO + "-it" if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO
+         msg = f"Training loss: {train.training_loss:.4f}"
+         trained_model.push_to_hub(repo_id, commit_message=msg, force=True)
+         trained_tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
+     else:
+         trained_model.save_pretrained("model")
+         trained_tokenizer.save_pretrained("tokenizer")

+ def main(push_to_hub=True, is_inst_finetune=False):
+     print("Loading Data..")
+     dataset = load_data()
+     print("Loaded data.")
+
+     if is_inst_finetune and INIT > 0:
+         print("Loading Tokenizer..")
+         tokenizer = load_tokenizer()
+         print("Loaded Tokenizer.")
      else:
+         print("Making Corpus..")
+         training_corpus = get_training_corpus(dataset)
+         print("Made Corpus.")
+
+         print("Making Tokenizer..")
+         tokenizer = create_tokenizer(training_corpus)
+         print(f"Made Tokenizer with size {len(tokenizer)}.")
+
+     # print("Adding Tokens..")
+     # num_new_tokens = update_tokenizer(tokenizer, dataset)
+     # print(f"Added {num_new_tokens} new tokens to the vocabulary")
+
+     if INIT == 0:
+         print("Adding Special Tokens..")
+         configure_tokenizer(tokenizer)
+         print("Added Tokens.")
+
+     if is_inst_finetune or INIT > 0:
+         print("Loading Model..")
+         model = load_model()
+         print("Loaded Model.")
+     else:
+         print("Creating Model..")
+         model = create_model(tokenizer)
+         print("Created Model.")
+
+     print(f"Tokenizer vocabulary size: {len(tokenizer)}")
+     print(f"Special tokens: {tokenizer.special_tokens_map}")
+
+     print("Resizing Token Embeddings..")
+     try:
+         model.resize_token_embeddings(len(tokenizer))
+     except RuntimeError as e:
+         raise RuntimeError(f"Error resizing token embeddings: {e}")
+     print("Resized Embeddings.")
+
+     print("Training Model..")
+     train_model(model, tokenizer, dataset, push_to_hub, is_inst_finetune)
+     raise FineError("Trained Model.")

  if __name__ == "__main__":
      try:
+         main(PUSH_TO_HUB, INSTRUCT_FINETUNE_BOOL)
      except Exception as e:
+         print(f'{type(e).__name__}: {e}')
+         Space().pause()