nroggendorff committed on
Commit 09118fb · verified · 1 Parent(s): 172239e

Update train.py

Files changed (1): train.py (+12 −6)
```diff
@@ -16,7 +16,9 @@ VOCAB_SIZE = 32000
 INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
 INSTRUCT_DATASET = "nroggendorff/elephant"
 OUTPUT_REPO = "nroggendorff/smallama"
-INSTRUCT_FINETUNE_BOOL = True
+INSTRUCT_FINETUNE_BOOL = False
+INIT = 1 # /7
+SHARD_SIZE = int(5e+6)
 FP16 = True
 WARMUP_STEPS = 0
 DECAY = 0
@@ -25,11 +27,12 @@ PUSH_TO_HUB = True
 
 def load_data():
     if not INSTRUCT_FINETUNE_BOOL:
-        dataset = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
-        dataset = Dataset.from_generator(lambda: dataset.take(int(5e+6)))
+        dataset = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train")#, streaming=True)
+        # dataset = Dataset.from_generator(lambda: dataset.take(int(5e+6)))
+        dataset = dataset.shard(num_shards=len(dataset) // SHARD_SIZE, index=INIT)
     else:
-        dataset = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)
-        dataset = Dataset.from_generator(lambda: dataset.take(int(5e+6)))
+        dataset = load_dataset(INSTRUCT_DATASET, split="train")#, streaming=True)
+        # dataset = Dataset.from_generator(lambda: dataset.take(int(5e+6)))
     return dataset
 
 def create_tokenizer(training_corpus):
@@ -182,7 +185,10 @@ def main(push_to_hub=True, is_inst_finetune=False):
         model.resize_token_embeddings(len(tokenizer))
         train_model(model, tokenizer, dataset, push_to_hub, True)
     else:
-        if INIT == 0:
-        model = create_model(tokenizer)
+        if INIT == 0:
+            model = create_model(tokenizer)
+        else:
+            model = load_model()
         train_model(model, tokenizer, dataset, push_to_hub, False)
 
 if __name__ == "__main__":
```
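
For context, the `load_data` change swaps the streaming `take()` pattern for fixed-size sharding: the corpus is loaded as a sized `Dataset`, cut into `len(dataset) // SHARD_SIZE` contiguous shards, and each training run consumes one shard selected by `INIT`. A minimal sketch of the arithmetic (the ~39M row count is an assumption used for illustration, not something stated in the commit):

```python
from datasets import load_dataset

SHARD_SIZE = int(5e+6)  # rows consumed per training run
INIT = 1                # 0-based index of the shard for this run

# Non-streaming load: shard() needs a sized Dataset, which is why the
# commit drops streaming=True and the take()-based generator.
dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", split="train")

# If the split holds roughly 39M rows, 39M // 5M = 7 shards -- which
# would explain the "# /7" note next to INIT in the diff.
num_shards = len(dataset) // SHARD_SIZE
dataset = dataset.shard(num_shards=num_shards, index=INIT)
```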
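
The new `INIT` guard in `main()` turns the script into a multi-run pipeline: run 0 initializes a fresh model, and every later run continues from whatever the previous run pushed. The body of `load_model()` is not shown in this diff; a hedged sketch of what it plausibly does, assuming it pulls the checkpoint back from `OUTPUT_REPO` on the Hub:

```python
from transformers import AutoModelForCausalLM

OUTPUT_REPO = "nroggendorff/smallama"

def load_model():
    # Assumption: resume from the checkpoint pushed at the end of the
    # previous shard's run, so pretraining proceeds shard by shard.
    return AutoModelForCausalLM.from_pretrained(OUTPUT_REPO)
```

Under that reading, bumping `INIT` from 0 through 6 across seven successive runs would walk the model through the whole 5M-row-per-run corpus while staying inside per-job time limits.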