nroggendorff committed on
Commit
31cd6e7
·
verified ·
1 Parent(s): 585e890

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +5 -4
train.py CHANGED
@@ -9,7 +9,7 @@ from tokenizers import ByteLevelBPETokenizer
9
 
10
  MAX_SEQ_LENGTH = 512
11
  BATCH_SIZE = 192
12
- EPOCHS = 3
13
  LEARNING_RATE = 2e-2
14
  FACTOR = 64
15
  VOCAB_SIZE = 32000
@@ -125,15 +125,16 @@ def train_model(model, tokenizer, dataset, push):
125
  optimizers=(optimizer, scheduler)
126
  )
127
 
128
- trainer.train()
129
 
130
  trained_model = trainer.model
131
  trained_tokenizer = trainer.tokenizer
132
 
133
  if push:
134
  repo_id = OUTPUT_REPO
135
- trained_model.push_to_hub(repo_id)
136
- trained_tokenizer.push_to_hub(repo_id)
 
137
  else:
138
  trained_model.save_pretrained("model")
139
  trained_tokenizer.save_pretrained("tokenizer")
 
9
 
10
  MAX_SEQ_LENGTH = 512
11
  BATCH_SIZE = 192
12
+ EPOCHS = 30
13
  LEARNING_RATE = 2e-2
14
  FACTOR = 64
15
  VOCAB_SIZE = 32000
 
125
  optimizers=(optimizer, scheduler)
126
  )
127
 
128
+ train = trainer.train()
129
 
130
  trained_model = trainer.model
131
  trained_tokenizer = trainer.tokenizer
132
 
133
  if push:
134
  repo_id = OUTPUT_REPO
135
+ msg = str(train.training_loss)
136
+ trained_model.push_to_hub(repo_id, commit_message=msg)
137
+ trained_tokenizer.push_to_hub(repo_id, commit_message=msg)
138
  else:
139
  trained_model.save_pretrained("model")
140
  trained_tokenizer.save_pretrained("tokenizer")