nroggendorff committed
Commit d47b527 · verified · 1 Parent(s): 45fd1bb

Update train.py

Files changed (1)
  1. train.py +2 -6
train.py CHANGED

@@ -8,7 +8,6 @@ from transformers import (
 from datasets import load_dataset, Dataset
 from tokenizers import ByteLevelBPETokenizer
 from torch.utils.data import DataLoader
-from torch.cuda.amp import autocast, GradScaler
 from itertools import islice
 
 BATCH_SIZE = 16
@@ -26,7 +25,7 @@ SHARD_SIZE = int(15e+5)
 FP16 = True
 WARMUP_STEPS = 0
 WEIGHT_DECAY = 0
-GRADIENT_ACCUMULATION_STEPS = 1#BATCH_SIZE // 4
+GRADIENT_ACCUMULATION_STEPS = 1
 PUSH_TO_HUB = True
 
 def load_data():
@@ -155,8 +154,6 @@ def train_model(model, tokenizer, dataset, push, isinst):
         save_total_limit=2,
     )
 
-    # dataset = dataset.shard(num_shards=len(dataset) // SHARD_SIZE, index=INIT)
-
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=WEIGHT_DECAY)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
@@ -212,5 +209,4 @@ def main(push_to_hub=True, is_inst_finetune=False):
     train_model(model, tokenizer, dataset, push_to_hub, is_inst_finetune)
 
 if __name__ == "__main__":
-    main(PUSH_TO_HUB, INSTRUCT_FINETUNE_BOOL)
-    raise Exception("Done baking!")
+    main(PUSH_TO_HUB, INSTRUCT_FINETUNE_BOOL)
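
The dropped torch.cuda.amp import and the hard-coded GRADIENT_ACCUMULATION_STEPS = 1 fit together: when training runs through the Hugging Face Trainer, fp16=True in TrainingArguments already handles autocast and gradient scaling internally, and accumulation is configured via gradient_accumulation_steps rather than a manual loop. A minimal sketch of how these constants would plug in, assuming the script builds its TrainingArguments roughly like this (output_dir and any value not visible in the diff are placeholders, not the repo's actual settings):

# Minimal sketch, not the repo's exact code: with the Trainer API, fp16=True
# is enough for mixed precision (the Trainer applies autocast and a GradScaler
# itself), so the removed torch.cuda.amp import is no longer needed.
from transformers import TrainingArguments

BATCH_SIZE = 16
FP16 = True
GRADIENT_ACCUMULATION_STEPS = 1  # effective batch = BATCH_SIZE * this * num devices

args = TrainingArguments(
    output_dir="output",                 # placeholder; real value not shown in the diff
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    fp16=FP16,
    save_total_limit=2,                  # matches the line visible in the diff
)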
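The commit also deletes the commented-out dataset.shard(...) line. For reference, a small standalone illustration of what that call does: datasets.Dataset.shard splits a dataset into num_shards pieces and returns the piece at the given index (the toy data below is made up, not from the repo):

# Standalone illustration of datasets.Dataset.shard, the call the deleted
# comment referenced. shard() partitions the dataset into num_shards pieces
# and returns the piece selected by index.
from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(10)]})
piece = ds.shard(num_shards=5, index=0)  # 10 rows / 5 shards -> 2 rows per shard
print(len(piece))  # 2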