nroggendorff committed · verified
Commit 173ea16 · 1 Parent(s): 2c9a0be

Update train.py

Files changed (1): train.py (+3 -5)
train.py CHANGED
@@ -8,11 +8,10 @@ from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer
 
 MAX_SEQ_LENGTH = 128
-BATCH_SIZE = 1024
-EPOCHS = 16
+BATCH_SIZE = 16
+EPOCHS = 8
 LEARNING_RATE = 1e-4
-FP16 = True
-FACTOR = 1
+FACTOR = 1024
 VOCAB_SIZE = 3200
 INPUT_DATASET = "nroggendorff/elephant"
 OUTPUT_REPO = "smallama"
@@ -94,7 +93,6 @@ def train_model(model, tokenizer, dataset):
         num_train_epochs=EPOCHS,
         per_device_train_batch_size=BATCH_SIZE,
         learning_rate=LEARNING_RATE,
-        fp16=FP16,
         optim="sgd"
     )
     dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
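In short, this commit swaps the old BATCH_SIZE = 1024 / EPOCHS = 16 / FP16 = True configuration for a smaller per-device batch, fewer epochs, and a larger FACTOR, and drops the fp16 flag from the training arguments. As a rough sketch of how the surviving constants plug into the Hugging Face TrainingArguments call shown in the second hunk (the output_dir value is an assumption, the surrounding Trainer wiring is omitted, and how train.py uses FACTOR is not visible in this diff):

    from transformers import TrainingArguments  # assumed import

    MAX_SEQ_LENGTH = 128
    BATCH_SIZE = 16        # was 1024 before this commit
    EPOCHS = 8             # was 16
    LEARNING_RATE = 1e-4
    FACTOR = 1024          # was 1; its use is outside this diff
    VOCAB_SIZE = 3200
    INPUT_DATASET = "nroggendorff/elephant"
    OUTPUT_REPO = "smallama"

    args = TrainingArguments(
        output_dir="output",                     # hypothetical path, not in this diff
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        optim="sgd",                             # fp16=FP16 was removed in this commit
    )

With the fp16 flag gone, the Trainer defaults to full fp32 training unless mixed precision is enabled elsewhere in the script.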