nroggendorff committed on
Commit
4aafa13
·
verified ·
1 Parent(s): 6008f38

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +9 -2
train.py CHANGED
@@ -44,6 +44,10 @@ def create_tokenizer(training_corpus):
44
  fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
45
  return fast_tokenizer
46
 
 
 
 
 
47
  def get_training_corpus(dataset):
48
  texts = []
49
  #for field in ['pretrain', 'instruct']:
@@ -163,8 +167,11 @@ def train_model(model, tokenizer, dataset, push, isinst):
163
 
164
  def main(push_to_hub=True, is_inst_finetune=False):
165
  dataset = load_data()
166
- training_corpus = get_training_corpus(dataset)
167
- tokenizer = create_tokenizer(training_corpus)
 
 
 
168
  configure_tokenizer(tokenizer)
169
  if is_inst_finetune:
170
  model = load_model()
 
44
  fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
45
  return fast_tokenizer
46
 
47
+ def load_tokenizer():
48
+ tok = AutoTokenizer.from_pretrained(OUTPUT_REPO)
49
+ return tok
50
+
51
  def get_training_corpus(dataset):
52
  texts = []
53
  #for field in ['pretrain', 'instruct']:
 
167
 
168
  def main(push_to_hub=True, is_inst_finetune=False):
169
  dataset = load_data()
170
+ if not is_inst_finetune:
171
+ training_corpus = get_training_corpus(dataset)
172
+ tokenizer = create_tokenizer(training_corpus)
173
+ else:
174
+ tokenizer = load_tokenizer()
175
  configure_tokenizer(tokenizer)
176
  if is_inst_finetune:
177
  model = load_model()