nroggendorff committed on
Commit 721bf9a · verified · 1 Parent(s): 93fda42

Update train.py

Files changed (1)
  1. train.py +28 -17
train.py CHANGED
@@ -8,7 +8,6 @@ from transformers import (
 from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer
 from torch.utils.data import DataLoader
-from torch.cuda.amp import autocast, GradScaler
 
 BATCH_SIZE = 8
 EPOCHS = 1
@@ -21,7 +20,6 @@ INSTRUCT_DATASET = "nroggendorff/elephant"
 OUTPUT_REPO = "nroggendorff/smallama"
 INSTRUCT_FINETUNE_BOOL = False
 INIT = 0
-SHARD_SIZE = int(5e+5)
 FP16 = True
 WARMUP_STEPS = 1000
 WEIGHT_DECAY = 0.01
@@ -32,11 +30,20 @@ NUM_WORKERS = 4
 def load_data():
     if not INSTRUCT_FINETUNE_BOOL:
         dataset = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
-        dataset = dataset.shard(num_shards=len(dataset) // SHARD_SIZE, index=INIT)
+        dataset = custom_shard_stream(dataset)
     else:
         dataset = load_dataset(INSTRUCT_DATASET, split="train")
     return dataset
 
+def custom_shard_stream(dataset, shard_size=5e5, shard_index=0):
+    def shard_generator():
+        count = 0
+        for example in dataset:
+            if count % shard_size == shard_index:
+                yield example
+            count += 1
+    return shard_generator()
+
 def create_tokenizer(training_corpus):
     tokenizer = ByteLevelBPETokenizer()
     special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
@@ -137,26 +144,30 @@ def train_model(model, tokenizer, dataset, push, isinst):
     scheduler = get_cosine_schedule_with_warmup(
         optimizer,
         num_warmup_steps=args.warmup_steps,
-        num_training_steps=(len(dataset) // args.per_device_train_batch_size) * args.num_train_epochs
+        num_training_steps=args.num_train_epochs
     )
 
-    dataloader = DataLoader(
-        dataset,
-        batch_size=BATCH_SIZE,
-        num_workers=NUM_WORKERS
+    dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer, isinst), batched=True, remove_columns=dataset.column_names)
+
+    trainer = trl.SFTTrainer(
+        model=model,
+        tokenizer=tokenizer,
+        args=args,
+        train_dataset=dataset,
+        optimizers=(optimizer, scheduler),
+        max_seq_length=MAX_SEQ_LENGTH
     )
-
-    for batch in dataloader:
-        batch = format_prompts(batch, tokenizer, isinst)
-
+
+    train_result = trainer.train()
+
     if push:
         repo_id = OUTPUT_REPO + "-it" if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO
-        msg = f"Training loss: {train.training_loss:.4f}"
-        model.push_to_hub(repo_id, commit_message=msg, force=True)
-        tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
+        msg = f"Training loss: {train_result.training_loss:.4f}"
+        trainer.model.push_to_hub(repo_id, commit_message=msg, force=True)
+        trainer.tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
     else:
-        model.save_pretrained("model")
-        tokenizer.save_pretrained("tokenizer")
+        trainer.model.save_pretrained("model")
+        trainer.tokenizer.save_pretrained("tokenizer")
 
 def main(push_to_hub=True, is_inst_finetune=False):
     dataset = load_data()
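
Context on the sharding change: the removed `dataset.shard(...)` call relied on `len(dataset)`, which is undefined for a streaming `IterableDataset`, so the commit walks the stream instead. Note the selection rule of the new helper: `count % shard_size == shard_index` keeps one example out of every `shard_size`, a strided subsample rather than a contiguous block. A minimal standalone sketch of the same logic, with a toy generator standing in for the real streaming dataset:

from itertools import islice

def custom_shard_stream(dataset, shard_size=5e5, shard_index=0):
    # Same logic as the helper added in this commit: keep every example
    # whose running index is congruent to shard_index modulo shard_size,
    # i.e. a 1-in-shard_size strided sample of the stream.
    def shard_generator():
        count = 0
        for example in dataset:
            if count % shard_size == shard_index:
                yield example
            count += 1
    return shard_generator()

# Toy stream in place of the streaming HF dataset (illustration only).
stream = ({"id": i} for i in range(1_500_000))
print(list(islice(custom_shard_stream(stream), 3)))
# -> [{'id': 0}, {'id': 500000}, {'id': 1000000}]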
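
On the training-loop change: the hand-rolled `DataLoader` loop is replaced by `trl.SFTTrainer`, which handles batching, mixed precision, and the optimizer step itself. A self-contained sketch of the same wiring, assuming an older trl release in which `SFTTrainer` still accepts `tokenizer`, `dataset_text_field`, and `max_seq_length` directly (newer releases move these to `SFTConfig`/`processing_class`); the tiny model and in-memory dataset are placeholders for illustration:

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

# In train.py the "text" column comes from format_prompts(); mocked here.
train_dataset = Dataset.from_dict({"text": ["<s>hello world</s>"] * 8})

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    report_to=[],
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,        # newer trl: processing_class=
    args=args,
    train_dataset=train_dataset,
    dataset_text_field="text",  # newer trl: set on SFTConfig
    max_seq_length=64,          # newer trl: set on SFTConfig
)
train_result = trainer.train()
print(f"Training loss: {train_result.training_loss:.4f}")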