Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
@@ -9,7 +9,7 @@ from tokenizers import ByteLevelBPETokenizer
|
|
9 |
|
10 |
MAX_SEQ_LENGTH = 512
|
11 |
BATCH_SIZE = 192
|
12 |
-
EPOCHS =
|
13 |
LEARNING_RATE = 2e-2
|
14 |
FACTOR = 64
|
15 |
VOCAB_SIZE = 32000
|
@@ -125,15 +125,16 @@ def train_model(model, tokenizer, dataset, push):
|
|
125 |
optimizers=(optimizer, scheduler)
|
126 |
)
|
127 |
|
128 |
-
trainer.train()
|
129 |
|
130 |
trained_model = trainer.model
|
131 |
trained_tokenizer = trainer.tokenizer
|
132 |
|
133 |
if push:
|
134 |
repo_id = OUTPUT_REPO
|
135 |
-
|
136 |
-
|
|
|
137 |
else:
|
138 |
trained_model.save_pretrained("model")
|
139 |
trained_tokenizer.save_pretrained("tokenizer")
|
|
|
9 |
|
10 |
MAX_SEQ_LENGTH = 512
|
11 |
BATCH_SIZE = 192
|
12 |
+
EPOCHS = 30
|
13 |
LEARNING_RATE = 2e-2
|
14 |
FACTOR = 64
|
15 |
VOCAB_SIZE = 32000
|
|
|
125 |
optimizers=(optimizer, scheduler)
|
126 |
)
|
127 |
|
128 |
+
train = trainer.train()
|
129 |
|
130 |
trained_model = trainer.model
|
131 |
trained_tokenizer = trainer.tokenizer
|
132 |
|
133 |
if push:
|
134 |
repo_id = OUTPUT_REPO
|
135 |
+
msg = str(train.training_loss)
|
136 |
+
trained_model.push_to_hub(repo_id, commit_message=msg)
|
137 |
+
trained_tokenizer.push_to_hub(repo_id, commit_message=msg)
|
138 |
else:
|
139 |
trained_model.save_pretrained("model")
|
140 |
trained_tokenizer.save_pretrained("tokenizer")
|