Update train.py
train.py CHANGED
@@ -8,7 +8,6 @@ from transformers import (
 from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer
 from torch.utils.data import DataLoader
-from torch.cuda.amp import autocast, GradScaler

 BATCH_SIZE = 8
 EPOCHS = 1
@@ -21,7 +20,6 @@ INSTRUCT_DATASET = "nroggendorff/elephant"
 OUTPUT_REPO = "nroggendorff/smallama"
 INSTRUCT_FINETUNE_BOOL = False
 INIT = 0
-SHARD_SIZE = int(5e+5)
 FP16 = True
 WARMUP_STEPS = 1000
 WEIGHT_DECAY = 0.01
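Note: the deleted torch.cuda.amp import suggests manual mixed-precision handling (autocast plus GradScaler) has been dropped in favor of the Trainer's built-in fp16 support. A minimal sketch of that delegation, assuming the script builds its TrainingArguments from the constants above; output_dir and the exact argument set are illustrative, not taken from this diff:

from transformers import TrainingArguments

# Illustrative sketch: with fp16=True the Trainer applies autocast and
# gradient scaling internally, so no manual GradScaler loop is needed.
args = TrainingArguments(
    output_dir="model",
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    warmup_steps=WARMUP_STEPS,
    weight_decay=WEIGHT_DECAY,
    fp16=FP16,
)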
@@ -32,11 +30,20 @@ NUM_WORKERS = 4
 def load_data():
     if not INSTRUCT_FINETUNE_BOOL:
         dataset = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
-        dataset =
+        dataset = custom_shard_stream(dataset)
     else:
         dataset = load_dataset(INSTRUCT_DATASET, split="train")
     return dataset

+def custom_shard_stream(dataset, shard_size=5e5, shard_index=0):
+    def shard_generator():
+        count = 0
+        for example in dataset:
+            if count % shard_size == shard_index:
+                yield example
+            count += 1
+    return shard_generator()
+
 def create_tokenizer(training_corpus):
     tokenizer = ByteLevelBPETokenizer()
     special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
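Note: as written, custom_shard_stream passes through only the examples whose running index is a multiple of shard_size (one in every 500,000 with the defaults), which is interleaved sampling rather than the contiguous 500,000-example shard the removed SHARD_SIZE constant implied. If a contiguous slice of the stream is the intent, the datasets streaming API offers skip/take; a sketch, with the helper name and parameters mirroring the commit's but purely illustrative:

# Hypothetical alternative: select a contiguous block of the stream
# without a manual counter, using IterableDataset.skip/take.
def contiguous_shard_stream(dataset, shard_size=500_000, shard_index=0):
    return dataset.skip(shard_index * shard_size).take(shard_size)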
@@ -137,26 +144,30 @@ def train_model(model, tokenizer, dataset, push, isinst):
     scheduler = get_cosine_schedule_with_warmup(
         optimizer,
         num_warmup_steps=args.warmup_steps,
-        num_training_steps=
+        num_training_steps=args.num_train_epochs
     )

-
-
-
-
+    dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer, isinst), batched=True, remove_columns=dataset.column_names)
+
+    trainer = trl.SFTTrainer(
+        model=model,
+        tokenizer=tokenizer,
+        args=args,
+        train_dataset=dataset,
+        optimizers=(optimizer, scheduler),
+        max_seq_length=MAX_SEQ_LENGTH
     )
-
-
-
-
+
+    train_result = trainer.train()
+
     if push:
         repo_id = OUTPUT_REPO + "-it" if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO
-        msg = f"Training loss: {
-        model.push_to_hub(repo_id, commit_message=msg, force=True)
-        tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
+        msg = f"Training loss: {train_result.training_loss:.4f}"
+        trainer.model.push_to_hub(repo_id, commit_message=msg, force=True)
+        trainer.tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
     else:
-        model.save_pretrained("model")
-        tokenizer.save_pretrained("tokenizer")
+        trainer.model.save_pretrained("model")
+        trainer.tokenizer.save_pretrained("tokenizer")

 def main(push_to_hub=True, is_inst_finetune=False):
     dataset = load_data()
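Note: get_cosine_schedule_with_warmup expects num_training_steps to be the total number of optimizer updates, whereas args.num_train_epochs is an epoch count (1 here), so the cosine decay would complete after a single step. A sketch of the usual computation, assuming a sized (non-streaming) dataset and no gradient accumulation:

# Illustrative sketch: derive the step budget from dataset size,
# batch size, and epoch count before building the schedule.
steps_per_epoch = len(dataset) // BATCH_SIZE
num_training_steps = steps_per_epoch * EPOCHS
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=num_training_steps,
)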
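One compatibility note: custom_shard_stream returns a plain Python generator, but train_model then calls dataset.map(..., remove_columns=dataset.column_names) and passes the result to trl.SFTTrainer, both of which expect a datasets object. A sketch of one way to keep that API available, assuming datasets.IterableDataset.from_generator; this rewrite is illustrative, not part of the commit:

from datasets import IterableDataset

# Hypothetical rewrite: wrap the generator back into an IterableDataset
# so .map() and .column_names remain available downstream.
def custom_shard_stream(dataset, shard_size=500_000, shard_index=0):
    def shard_generator():
        for count, example in enumerate(dataset):
            if count % shard_size == shard_index:
                yield example
    return IterableDataset.from_generator(shard_generator)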