Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
@@ -8,7 +8,6 @@ from transformers import (
|
|
8 |
from datasets import load_dataset, Dataset
|
9 |
from tokenizers import ByteLevelBPETokenizer
|
10 |
from torch.utils.data import DataLoader
|
11 |
-
from torch.cuda.amp import autocast, GradScaler
|
12 |
from itertools import islice
|
13 |
|
14 |
BATCH_SIZE = 16
|
@@ -26,7 +25,7 @@ SHARD_SIZE = int(15e+5)
|
|
26 |
FP16 = True
|
27 |
WARMUP_STEPS = 0
|
28 |
WEIGHT_DECAY = 0
|
29 |
-
GRADIENT_ACCUMULATION_STEPS = 1
|
30 |
PUSH_TO_HUB = True
|
31 |
|
32 |
def load_data():
|
@@ -155,8 +154,6 @@ def train_model(model, tokenizer, dataset, push, isinst):
|
|
155 |
save_total_limit=2,
|
156 |
)
|
157 |
|
158 |
-
# dataset = dataset.shard(num_shards=len(dataset) // SHARD_SIZE, index=INIT)
|
159 |
-
|
160 |
optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=WEIGHT_DECAY)
|
161 |
scheduler = get_cosine_schedule_with_warmup(
|
162 |
optimizer,
|
@@ -212,5 +209,4 @@ def main(push_to_hub=True, is_inst_finetune=False):
|
|
212 |
train_model(model, tokenizer, dataset, push_to_hub, is_inst_finetune)
|
213 |
|
214 |
if __name__ == "__main__":
|
215 |
-
main(PUSH_TO_HUB, INSTRUCT_FINETUNE_BOOL)
|
216 |
-
raise Exception("Done baking!")
|
|
|
8 |
from datasets import load_dataset, Dataset
|
9 |
from tokenizers import ByteLevelBPETokenizer
|
10 |
from torch.utils.data import DataLoader
|
|
|
11 |
from itertools import islice
|
12 |
|
13 |
BATCH_SIZE = 16
|
|
|
25 |
FP16 = True
|
26 |
WARMUP_STEPS = 0
|
27 |
WEIGHT_DECAY = 0
|
28 |
+
GRADIENT_ACCUMULATION_STEPS = 1
|
29 |
PUSH_TO_HUB = True
|
30 |
|
31 |
def load_data():
|
|
|
154 |
save_total_limit=2,
|
155 |
)
|
156 |
|
|
|
|
|
157 |
optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=WEIGHT_DECAY)
|
158 |
scheduler = get_cosine_schedule_with_warmup(
|
159 |
optimizer,
|
|
|
209 |
train_model(model, tokenizer, dataset, push_to_hub, is_inst_finetune)
|
210 |
|
211 |
if __name__ == "__main__":
|
212 |
+
main(PUSH_TO_HUB, INSTRUCT_FINETUNE_BOOL)
|
|