Update train.py
train.py
CHANGED
@@ -44,6 +44,10 @@ def create_tokenizer(training_corpus):
     fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
     return fast_tokenizer
 
+def load_tokenizer():
+    tok = AutoTokenizer.from_pretrained(OUTPUT_REPO)
+    return tok
+
 def get_training_corpus(dataset):
     texts = []
     #for field in ['pretrain', 'instruct']:
@@ -163,8 +167,11 @@ def train_model(model, tokenizer, dataset, push, isinst):
 
 def main(push_to_hub=True, is_inst_finetune=False):
     dataset = load_data()
-    training_corpus = get_training_corpus(dataset)
-    tokenizer = create_tokenizer(training_corpus)
+    if not is_inst_finetune:
+        training_corpus = get_training_corpus(dataset)
+        tokenizer = create_tokenizer(training_corpus)
+    else:
+        tokenizer = load_tokenizer()
     configure_tokenizer(tokenizer)
     if is_inst_finetune:
         model = load_model()
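For context, the first hunk's new helper just reloads whatever tokenizer an earlier run pushed to the Hub. A minimal, hypothetical sketch of that usage, assuming OUTPUT_REPO is the Hub repo id defined elsewhere in train.py (the value below is only a placeholder):

from transformers import AutoTokenizer

OUTPUT_REPO = "your-username/your-model"  # placeholder; train.py defines the real repo id

def load_tokenizer():
    # Reuse the tokenizer already pushed to the Hub instead of retraining one
    return AutoTokenizer.from_pretrained(OUTPUT_REPO)

if __name__ == "__main__":
    tok = load_tokenizer()
    print(tok("hello world").input_ids)  # encoded with the previously trained vocabulary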
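The second hunk presumably exists to keep the instruction-finetuning vocabulary identical to the one the pretrained model's embeddings were built against: pretraining still trains a fresh tokenizer from the corpus, while finetuning reloads the pushed one. A hedged sanity-check sketch along those lines (placeholder repo id; the real model and tokenizer come from OUTPUT_REPO in train.py):

from transformers import AutoModelForCausalLM, AutoTokenizer

OUTPUT_REPO = "your-username/your-model"  # placeholder Hub repo id

model = AutoModelForCausalLM.from_pretrained(OUTPUT_REPO)
tokenizer = AutoTokenizer.from_pretrained(OUTPUT_REPO)

# The reloaded vocabulary must fit inside the pretrained embedding table,
# otherwise the finetuning run would index embeddings that do not exist.
assert model.get_input_embeddings().num_embeddings >= len(tokenizer), (
    "tokenizer vocabulary exceeds the model's embedding table"
)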