Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
@@ -8,11 +8,10 @@ from datasets import load_dataset
|
|
8 |
from tokenizers import ByteLevelBPETokenizer
|
9 |
|
10 |
MAX_SEQ_LENGTH = 128
|
11 |
-
BATCH_SIZE =
|
12 |
-
EPOCHS =
|
13 |
LEARNING_RATE = 1e-4
|
14 |
-
|
15 |
-
FACTOR = 1
|
16 |
VOCAB_SIZE = 3200
|
17 |
INPUT_DATASET = "nroggendorff/elephant"
|
18 |
OUTPUT_REPO = "smallama"
|
@@ -94,7 +93,6 @@ def train_model(model, tokenizer, dataset):
|
|
94 |
num_train_epochs=EPOCHS,
|
95 |
per_device_train_batch_size=BATCH_SIZE,
|
96 |
learning_rate=LEARNING_RATE,
|
97 |
-
fp16=FP16,
|
98 |
optim="sgd"
|
99 |
)
|
100 |
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
|
|
|
8 |
from tokenizers import ByteLevelBPETokenizer
|
9 |
|
10 |
MAX_SEQ_LENGTH = 128
|
11 |
+
BATCH_SIZE = 16
|
12 |
+
EPOCHS = 8
|
13 |
LEARNING_RATE = 1e-4
|
14 |
+
FACTOR = 1024
|
|
|
15 |
VOCAB_SIZE = 3200
|
16 |
INPUT_DATASET = "nroggendorff/elephant"
|
17 |
OUTPUT_REPO = "smallama"
|
|
|
93 |
num_train_epochs=EPOCHS,
|
94 |
per_device_train_batch_size=BATCH_SIZE,
|
95 |
learning_rate=LEARNING_RATE,
|
|
|
96 |
optim="sgd"
|
97 |
)
|
98 |
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
|