nroggendorff committed
Commit 5720fe4 · verified · Parent: 382ddc1

Update train.py

Files changed (1):
  1. train.py +21 -17
train.py CHANGED
@@ -13,23 +13,25 @@ from torch.utils.data import DataLoader
 from itertools import islice
 
 BATCH_SIZE = 16
-EPOCHS = 1
+EPOCHS = 3
 LEARNING_RATE = 2e-4
 FACTOR = 12 ** 3 // 3
-MAX_SEQ_LENGTH = 128
+MAX_SEQ_LENGTH = 512
 VOCAB_SIZE = 32000
 INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
 INSTRUCT_DATASET = "nroggendorff/elephant"
 OUTPUT_REPO = "nroggendorff/smallama"
 INSTRUCT_FINETUNE_BOOL = False
 INIT = 0
-SHARD_SIZE = int(2e+6)
+SHARD_SIZE = int(2e+5)
 FP16 = True
-WARMUP_STEPS = 50
 WEIGHT_DECAY = 1e-3
-GRADIENT_ACCUMULATION_STEPS = 2
+GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // 4
+WARMUP_STEPS = ((SHARD_SIZE // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)) * EPOCHS) // 10
 PUSH_TO_HUB = True
 
+total_steps = WARMUP_STEPS * 10
+
 class Space:
     def __init__(self):
         self.api = HfApi()
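For reference, the derived step constants introduced in this hunk work out as follows, assuming a single device so that the effective batch is BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS:

BATCH_SIZE = 16
EPOCHS = 3
SHARD_SIZE = int(2e+5)                                      # 200,000 examples per shard

GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // 4               # 4
effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS  # 64
steps_per_epoch = SHARD_SIZE // effective_batch             # 3125 optimizer steps
WARMUP_STEPS = (steps_per_epoch * EPOCHS) // 10             # 937, about 10% of the run
total_steps = WARMUP_STEPS * 10                             # 9370

So total_steps is defined as exactly ten warmups rather than recomputed from the data, which leaves it 5 steps short of the true 9375 optimizer steps here because of the floor division.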
@@ -38,15 +40,17 @@ class Space:
 space = Space()
 
 def load_data():
-    if not INSTRUCT_FINETUNE_BOOL:
-        dataset = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
-        start = INIT * SHARD_SIZE
-        dataset = Dataset.from_dict({'text': [example['text'] for example in islice(dataset, start, start + SHARD_SIZE)]})
-    else:
-        dataset = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)
+    try:
+        if not INSTRUCT_FINETUNE_BOOL:
+            dataset = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
+        else:
+            dataset = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)
+
         start = INIT * SHARD_SIZE
-        dataset = Dataset.from_dict({'text': [example['text'] for example in islice(dataset, start, start + SHARD_SIZE)]})
-    return dataset
+        data_list = list(islice(dataset, start, start + SHARD_SIZE))
+
+        dataset = Dataset.from_dict({'text': [example['text'] for example in data_list]})
+        return dataset
 
 def create_tokenizer(training_corpus):
     tokenizer = ByteLevelBPETokenizer()
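One thing to flag in this hunk: the new load_data opens a try: block, but no matching except or finally appears anywhere in the function (the trailing context drops straight into create_tokenizer), so Python will refuse to parse the file. A minimal completion might look like the sketch below, using the script's module-level constants; the re-raise handler is an assumption, not part of the commit.

from itertools import islice

from datasets import Dataset, load_dataset

def load_data():
    try:
        if not INSTRUCT_FINETUNE_BOOL:
            dataset = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
        else:
            dataset = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)

        # Materialize one shard of the streaming dataset in memory.
        start = INIT * SHARD_SIZE
        data_list = list(islice(dataset, start, start + SHARD_SIZE))

        dataset = Dataset.from_dict({'text': [example['text'] for example in data_list]})
        return dataset
    except Exception as err:
        # Hypothetical handler: fail loudly instead of falling through to None.
        raise RuntimeError(f"failed to load shard {INIT}") from err

Note also that islice has to read and discard the first start examples before yielding any, so shards with INIT > 0 take progressively longer to load.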
@@ -158,10 +162,10 @@ def train_model(model, tokenizer, dataset, push, isinst):
         weight_decay=WEIGHT_DECAY,
         gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
         fp16=FP16,
-        save_steps=int(1e+10),
-        logging_steps=5000,
+        save_steps=WARMUP_STEPS,
+        logging_steps=WARMUP_STEPS,
         evaluation_strategy="no",
-        eval_steps=2000,
+        eval_steps=1,
         save_total_limit=2,
     )
 
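For scale, tying the cadence to WARMUP_STEPS means checkpoints and log lines land every 937 optimizer steps with the values above:

saves_per_run = total_steps // WARMUP_STEPS   # 9370 // 937 = 10 checkpoints
# save_total_limit=2 keeps only the newest two of those;
# eval_steps=1 is inert while evaluation_strategy="no".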
@@ -169,7 +173,7 @@ def train_model(model, tokenizer, dataset, push, isinst):
     scheduler = get_cosine_schedule_with_warmup(
         optimizer,
         num_warmup_steps=args.warmup_steps,
-        num_training_steps=(len(dataset) // args.per_device_train_batch_size) * args.num_train_epochs
+        num_training_steps=total_steps
     )
 
     dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer, isinst), batched=True, remove_columns=dataset.column_names)
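The replaced num_training_steps formula ignored gradient accumulation. A quick comparison, assuming len(dataset) equals SHARD_SIZE once load_data has materialized one shard:

SHARD_SIZE, BATCH_SIZE, EPOCHS = int(2e+5), 16, 3
ACCUM = BATCH_SIZE // 4                                       # GRADIENT_ACCUMULATION_STEPS

old_estimate = (SHARD_SIZE // BATCH_SIZE) * EPOCHS            # 37500 scheduler steps
actual_steps = (SHARD_SIZE // (BATCH_SIZE * ACCUM)) * EPOCHS  # 9375 optimizer steps
# The optimizer steps once per ACCUM micro-batches, so the old value ran
# ~4x high and the cosine decay would never have finished; total_steps
# (9370) tracks the real count up to floor-division rounding.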
 