nroggendorff committed
Commit 82312a4 · verified · 1 Parent(s): 9982b19

Update train.py

it's gonna take a while to fix all of these issues

Files changed (1)
  1. train.py +77 -62
train.py CHANGED
@@ -13,26 +13,57 @@ from trl import SFTConfig, SFTTrainer
 from torch.utils.data import DataLoader
 from itertools import islice
 
-BATCH_SIZE = 16
-EPOCHS = 3
-LEARNING_RATE = 2e-4
-FACTOR = 12 ** 3 // 3
-MAX_SEQ_LENGTH = 512
-VOCAB_SIZE = 32000
-INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
-INSTRUCT_DATASET = "nroggendorff/elephant"
-OUTPUT_REPO = "nroggendorff/smallama"
-INSTRUCT_FINETUNE_BOOL = False
-INIT = 0
-SHARD_SIZE = int(2e+5)
-FP16 = True
-WEIGHT_DECAY = 1e-3
-GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // 4
-
-PUSH_TO_HUB = True
-
-total_steps = (SHARD_SIZE * EPOCHS) // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)
-WARMUP_STEPS = total_steps * 0.1
+class Config:
+    def __init__(self):
+        # Model and training hyperparameters
+        self.BATCH_SIZE = 16
+        self.EPOCHS = 3
+        self.LEARNING_RATE = 2e-4
+        self.MAX_SEQ_LENGTH = 512
+        self.VOCAB_SIZE = 32000
+        self.FP16 = True
+        self.WEIGHT_DECAY = 1e-3
+        self.GRADIENT_ACCUMULATION_STEPS = self.BATCH_SIZE // 4
+
+        # Dataset configurations
+        self.INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
+        self.INSTRUCT_DATASET = "nroggendorff/elephant"
+        self.SHARD_SIZE = int(2e+5)
+
+        # Output and repo settings
+        self.OUTPUT_REPO = "nroggendorff/smallama"
+        self.PUSH_TO_HUB = True
+        self.INSTRUCT_FINETUNE_BOOL = False
+
+        # Training steps and warmup
+        self.FACTOR = 12 ** 3 // 3
+        self.TOTAL_STEPS = (self.SHARD_SIZE * self.EPOCHS) // (self.BATCH_SIZE * self.GRADIENT_ACCUMULATION_STEPS)
+        self.WARMUP_STEPS = int(self.TOTAL_STEPS * 0.1)
+
+        # Initial state for shard offset
+        self.INIT = 0
+
+        # ignore
+        self.getConfig = lambda: self._args()
+
+    # @staticmethod
+    def _args(self):
+        return SFTConfig(
+            output_dir="model",
+            num_train_epochs=self.EPOCHS,
+            per_device_train_batch_size=self.BATCH_SIZE,
+            learning_rate=self.LEARNING_RATE,
+            warmup_steps=self.WARMUP_STEPS,
+            weight_decay=self.WEIGHT_DECAY,
+            gradient_accumulation_steps=self.GRADIENT_ACCUMULATION_STEPS,
+            fp16=self.FP16,
+            save_steps=int(self.WARMUP_STEPS * 5),
+            logging_steps=int(self.WARMUP_STEPS),
+            save_total_limit=2,
+            report_to="none",
+        )
+
+config = Config()
 
 class Space:
     def __init__(self):
@@ -45,13 +76,13 @@ class FineError(Exception):
         super().__init__(self.message)
 
 def load_data():
-    if not INSTRUCT_FINETUNE_BOOL:
-        dataset = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
+    if not config.INSTRUCT_FINETUNE_BOOL:
+        dataset = load_dataset(config.INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
     else:
-        dataset = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)
+        dataset = load_dataset(config.INSTRUCT_DATASET, split="train", streaming=True)
 
-    start = INIT * SHARD_SIZE
-    data_list = list(islice(dataset, start, start + SHARD_SIZE))
+    start = config.INIT * config.SHARD_SIZE
+    data_list = list(islice(dataset, start, start + config.SHARD_SIZE))
 
     dataset = Dataset.from_dict({'text': [example['text'] for example in data_list]})
     return dataset
@@ -64,7 +95,7 @@ def encode_decode(texts, tok):
         texts,
         padding="max_length",
         truncation=True,
-        max_length=MAX_SEQ_LENGTH,
+        max_length=config.MAX_SEQ_LENGTH,
         return_tensors="pt"
     ).input_ids
 
@@ -72,7 +103,7 @@ def encode_decode(texts, tok):
         decoded_texts = tok.batch_decode(tokenized_texts)
     else:
         print('Found invalid entry in examples. Returning dummy..')
-        decoded_texts = [tokenizer.pad_token * MAX_SEQ_LENGTH]
+        decoded_texts = [tok.pad_token * config.MAX_SEQ_LENGTH]
 
     islist = not len(decoded_texts) == 1
 
@@ -83,7 +114,7 @@ def create_tokenizer(training_corpus):
     special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
     tokenizer.train_from_iterator(
         training_corpus,
-        vocab_size=VOCAB_SIZE,
+        vocab_size=config.VOCAB_SIZE,
         min_frequency=2,
         special_tokens=special_tokens
     )
@@ -91,7 +122,7 @@ def create_tokenizer(training_corpus):
     return fast_tokenizer
 
 def load_tokenizer():
-    return AutoTokenizer.from_pretrained(OUTPUT_REPO + '-it' if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO)
+    return AutoTokenizer.from_pretrained(config.OUTPUT_REPO + '-it' if config.INSTRUCT_FINETUNE_BOOL else config.OUTPUT_REPO)
 
 def get_training_corpus(dataset):
     for i in range(0, len(dataset['text']), 1000):
@@ -125,13 +156,13 @@ def format_prompts(examples, tokenizer, isinst):
     return {'text': coded_texts}
 
 def create_model(tokenizer):
-    config = LlamaConfig(
+    model_config = LlamaConfig(
         vocab_size=tokenizer.vocab_size,
-        hidden_size=FACTOR,
-        intermediate_size=FACTOR * 4,
+        hidden_size=config.FACTOR,
+        intermediate_size=config.FACTOR * 4,
         num_hidden_layers=12,
         num_attention_heads=12,
-        max_position_embeddings=MAX_SEQ_LENGTH,
+        max_position_embeddings=config.MAX_SEQ_LENGTH,
         rms_norm_eps=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -140,10 +171,10 @@ def create_model(tokenizer):
         eos_token_id=tokenizer.eos_token_id,
         tie_word_embeddings=False,
     )
-    return LlamaForCausalLM(config)
+    return LlamaForCausalLM(model_config)
 
 def load_model():
-    return AutoModelForCausalLM.from_pretrained(OUTPUT_REPO + '-it' if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO)
+    return AutoModelForCausalLM.from_pretrained(config.OUTPUT_REPO + '-it' if config.INSTRUCT_FINETUNE_BOOL else config.OUTPUT_REPO)
 
 def configure_tokenizer(tokenizer):
     special_tokens = {
@@ -154,11 +185,11 @@ def configure_tokenizer(tokenizer):
         "mask_token": "<mask>",
         "additional_special_tokens": []
     }
-    if INSTRUCT_FINETUNE_BOOL:
+    if config.INSTRUCT_FINETUNE_BOOL:
         special_tokens["additional_special_tokens"] = ["<|user|>", "<|bot|>", "<|end|>"]
     tokenizer.add_special_tokens(special_tokens)
 
-    if INSTRUCT_FINETUNE_BOOL:
+    if config.INSTRUCT_FINETUNE_BOOL:
         tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
         tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")
 
@@ -189,25 +220,9 @@ def update_tokenizer(tokenizer, dataset, batch_size=1000):
     return 0
 
 def train_model(model, tokenizer, dataset, push, isinst):
-    args = SFTConfig(
-        output_dir="model",
-        num_train_epochs=EPOCHS,
-        per_device_train_batch_size=BATCH_SIZE,
-        learning_rate=LEARNING_RATE,
-        optim="adamw_torch",
-        warmup_steps=WARMUP_STEPS,
-        weight_decay=WEIGHT_DECAY,
-        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
-        fp16=FP16,
-        save_steps=WARMUP_STEPS * 5,
-        logging_steps=WARMUP_STEPS,
-        eval_strategy="no",
-        report_to="no",
-        # eval_steps=WARMUP_STEPS,
-        save_total_limit=2,
-    )
+    args = config.getConfig()
 
-    optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=WEIGHT_DECAY)
+    optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=config.WEIGHT_DECAY)
     scheduler = get_cosine_schedule_with_warmup(
         optimizer,
         num_warmup_steps=args.warmup_steps,
@@ -234,13 +249,13 @@ def train_model(model, tokenizer, dataset, push, isinst):
     except RuntimeError as e:
         print(f"Error processing test batch: {e}")
 
-    trainer = trl.SFTTrainer(
+    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=args,
        train_dataset=dataset,
        # dataset_text_field='text',
-        max_seq_length=MAX_SEQ_LENGTH,
+        max_seq_length=config.MAX_SEQ_LENGTH,
        optimizers=(optimizer, scheduler)
    )
 
@@ -250,7 +265,7 @@ def train_model(model, tokenizer, dataset, push, isinst):
     trained_tokenizer = trainer.tokenizer
 
     if push:
-        repo_id = OUTPUT_REPO + "-it" if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO
+        repo_id = config.OUTPUT_REPO + "-it" if config.INSTRUCT_FINETUNE_BOOL else config.OUTPUT_REPO
         msg = f"Training loss: {train.training_loss:.4f}"
         trained_model.push_to_hub(repo_id, commit_message=msg, force=True)
         trained_tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
@@ -258,12 +273,12 @@ def train_model(model, tokenizer, dataset, push, isinst):
         trained_model.save_pretrained("model")
         trained_tokenizer.save_pretrained("tokenizer")
 
-def main(push_to_hub=True, is_inst_finetune=False):
+def main(push_to_hub=True, is_inst_finetune=config.INSTRUCT_FINETUNE_BOOL):
     print("Loading Data..")
     dataset = load_data()
     print("Loaded data.")
 
-    if is_inst_finetune and INIT > 0:
+    if is_inst_finetune and config.INIT > 0:
         print("Loading Tokenizer..")
         tokenizer = load_tokenizer()
         print("Loaded Tokenizer.")
@@ -285,7 +300,7 @@ def main(push_to_hub=True, is_inst_finetune=False):
     configure_tokenizer(tokenizer)
     print("Added Tokens.")
 
-    if is_inst_finetune or INIT > 0:
+    if is_inst_finetune or config.INIT > 0:
         print("Loading Model..")
         model = load_model()
         print("Loaded Model.")
@@ -310,7 +325,7 @@ def main(push_to_hub=True, is_inst_finetune=False):
 
 if __name__ == "__main__":
     try:
-        main(PUSH_TO_HUB, INSTRUCT_FINETUNE_BOOL)
+        main()
     except Exception as e:
         print(f'{type(e).__name__}: {e}')
         Space().pause()
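
For reference, here is a small standalone sketch (not part of the commit) that reproduces the arithmetic now encapsulated in the Config class, using the same constants that appear in the diff. It shows the step counts and model widths these settings imply; nothing below is imported from train.py.

    # Standalone sketch: rederive the values computed inside Config.__init__.
    BATCH_SIZE = 16
    EPOCHS = 3
    SHARD_SIZE = int(2e+5)                         # 200,000 examples per shard
    GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // 4  # 4

    TOTAL_STEPS = (SHARD_SIZE * EPOCHS) // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)
    WARMUP_STEPS = int(TOTAL_STEPS * 0.1)

    FACTOR = 12 ** 3 // 3                          # hidden_size passed to LlamaConfig

    print(TOTAL_STEPS)         # 9375 optimizer steps over 3 epochs of the 200k-example shard
    print(WARMUP_STEPS)        # 937 -> also logging_steps; save_steps = int(937 * 5) = 4685
    print(FACTOR, FACTOR * 4)  # 576 hidden size, 2304 intermediate size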