nroggendorff committed on
Commit
4b91394
·
verified ·
1 Parent(s): 7e8e146

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +13 -4
train.py CHANGED
@@ -47,9 +47,17 @@ def create_tokenizer(training_corpus):
47
  fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
48
  return fast_tokenizer
49
 
50
- def load_tokenizer():
51
- tok = AutoTokenizer.from_pretrained(OUTPUT_REPO)
52
- return tok
 
 
 
 
 
 
 
 
53
 
54
  def get_training_corpus(dataset):
55
  texts = []
@@ -175,7 +183,8 @@ def main(push_to_hub=True, is_inst_finetune=False):
175
  training_corpus = get_training_corpus(dataset)
176
  tokenizer = create_tokenizer(training_corpus)
177
  else:
178
- tokenizer = load_tokenizer()
 
179
  configure_tokenizer(tokenizer)
180
  if is_inst_finetune:
181
  model = load_model()
 
47
  fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
48
  return fast_tokenizer
49
 
50
def load_tokenizer(training_corpus):
    """Load the previously-pushed tokenizer from OUTPUT_REPO and retrain it.

    Args:
        training_corpus: An iterator of text batches (as produced by
            ``get_training_corpus``) to retrain the tokenizer on.

    Returns:
        The tokenizer loaded from ``OUTPUT_REPO``, after the training call.
    """
    tokenizer = AutoTokenizer.from_pretrained(OUTPUT_REPO)
    special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
    # BUG FIX: list.append() accepts exactly one argument, so the original
    # special_tokens.append("<|user|>", "<|bot|>", "<|end|>") raised
    # "TypeError: append() takes exactly one argument (3 given)" at runtime.
    # extend() with a list adds all three chat-role tokens as intended.
    special_tokens.extend(["<|user|>", "<|bot|>", "<|end|>"])
    # NOTE(review): AutoTokenizer.from_pretrained returns a
    # PreTrainedTokenizerFast, whose retraining API is
    # train_new_from_iterator(...) and RETURNS a new tokenizer; a
    # train_from_iterator method with this signature exists only on the
    # low-level `tokenizers` implementation classes (e.g. ByteLevelBPETokenizer).
    # Confirm what kind of tokenizer OUTPUT_REPO actually holds — if it is a
    # fast tokenizer, this call will fail or silently not rebind the result.
    tokenizer.train_from_iterator(
        training_corpus,
        vocab_size=VOCAB_SIZE,
        min_frequency=2,
        special_tokens=special_tokens,
    )
    return tokenizer
61
 
62
  def get_training_corpus(dataset):
63
  texts = []
 
183
  training_corpus = get_training_corpus(dataset)
184
  tokenizer = create_tokenizer(training_corpus)
185
  else:
186
+ training_corpus = get_training_corpus(dataset)
187
+ tokenizer = load_tokenizer(training_corpus)
188
  configure_tokenizer(tokenizer)
189
  if is_inst_finetune:
190
  model = load_model()