nroggendorff committed on
Commit a3dd2de · verified · 1 Parent(s): 387158c

Update train.py

Files changed (1)
  1. train.py +21 -4
train.py CHANGED
@@ -3,18 +3,22 @@ import os
 import torch
 import trl

-from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast, AdamW, get_linear_schedule_with_warmup
 from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer

 MAX_SEQ_LENGTH = 128
-BATCH_SIZE = 1024
+BATCH_SIZE = 512
 EPOCHS = 10
-LEARNING_RATE = 1e-5
+LEARNING_RATE = 2e-5
 FACTOR = 4
 VOCAB_SIZE = 32000
 INPUT_DATASET = "nroggendorff/oak"
 OUTPUT_REPO = "smallama"
+FP16 = True
+WARMUP_STEPS = 500
+DECAY = 0.01
+GRADIENT_ACCUMILATION_STEPS = 4
 PUSH_TO_HUB = True

 def load_data():
@@ -94,8 +98,21 @@ def train_model(model, tokenizer, dataset, push):
         num_train_epochs=EPOCHS,
         per_device_train_batch_size=BATCH_SIZE,
         learning_rate=LEARNING_RATE,
-        optim="sgd"
+        optim="adamw_torch",
+        warmup_steps=WARMUP_STEPS,
+        weight_decay=DECAY,
+        gradient_accumulation_steps=GRADIENT_ACCUMILATION_STEPS,
+        fp16=True,
+        evaluation_strategy="steps"
     )
+
+    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
+    scheduler = get_linear_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=args.warmup_steps,
+        num_training_steps=len(dataset) * args.num_train_epochs // args.gradient_accumulation_steps
+    )
+
     dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
     trainer = trl.SFTTrainer(
         model=model,
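The commit builds `optimizer` and `scheduler` by hand, but the visible hunk ends before they are handed to the trainer, so how they are consumed is only implied. As a minimal sketch (not part of the commit), `trl.SFTTrainer` inherits the `optimizers` parameter from `transformers.Trainer`, which accepts an `(optimizer, lr_scheduler)` tuple; the remaining keyword arguments below are assumptions about how the rest of `train_model` is wired.

    # Sketch only: passing the manually built optimizer/scheduler to the trainer.
    # The diff does not show this call completing; tokenizer/args/train_dataset
    # here are assumed names for the surrounding train_model() context.
    trainer = trl.SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=args,
        train_dataset=dataset,
        optimizers=(optimizer, scheduler),  # Trainer API: (optimizer, lr_scheduler)
    )
    trainer.train()

Note that when an explicit `optimizers` tuple is supplied, the Trainer uses it instead of constructing one from `optim`, `warmup_steps`, and `weight_decay` in `TrainingArguments`, so the two mechanisms added in this commit overlap. Also, with `per_device_train_batch_size=512` and `gradient_accumulation_steps=4`, each optimizer step sees an effective batch of 2048 sequences per device.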