Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
@@ -8,12 +8,12 @@ from datasets import load_dataset
|
|
8 |
from tokenizers import ByteLevelBPETokenizer
|
9 |
|
10 |
MAX_SEQ_LENGTH = 512
|
11 |
-
BATCH_SIZE =
|
12 |
-
EPOCHS =
|
13 |
LEARNING_RATE = 1e-5
|
14 |
FACTOR = 2
|
15 |
VOCAB_SIZE = 3200
|
16 |
-
INPUT_DATASET = "nroggendorff/
|
17 |
OUTPUT_REPO = "sson"
|
18 |
PUSH_TO_HUB = True
|
19 |
|
@@ -27,7 +27,7 @@ def create_tokenizer(training_corpus):
|
|
27 |
training_corpus,
|
28 |
vocab_size=VOCAB_SIZE,
|
29 |
min_frequency=2,
|
30 |
-
special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"
|
31 |
)
|
32 |
|
33 |
fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
|
@@ -78,15 +78,15 @@ def configure_tokenizer(tokenizer):
|
|
78 |
"unk_token": "<unk>",
|
79 |
"pad_token": "<pad>",
|
80 |
"mask_token": "<mask>",
|
81 |
-
|
82 |
}
|
83 |
tokenizer.add_special_tokens(special_tokens)
|
84 |
|
85 |
-
|
86 |
-
|
87 |
|
88 |
-
|
89 |
-
|
90 |
|
91 |
def train_model(model, tokenizer, dataset, push):
|
92 |
args = TrainingArguments(
|
@@ -96,7 +96,7 @@ def train_model(model, tokenizer, dataset, push):
|
|
96 |
learning_rate=LEARNING_RATE,
|
97 |
optim="sgd"
|
98 |
)
|
99 |
-
|
100 |
trainer = trl.SFTTrainer(
|
101 |
model=model,
|
102 |
tokenizer=tokenizer,
|
|
|
8 |
from tokenizers import ByteLevelBPETokenizer
|
9 |
|
10 |
MAX_SEQ_LENGTH = 512
|
11 |
+
BATCH_SIZE = 2048
|
12 |
+
EPOCHS = 16
|
13 |
LEARNING_RATE = 1e-5
|
14 |
FACTOR = 2
|
15 |
VOCAB_SIZE = 3200
|
16 |
+
INPUT_DATASET = "nroggendorff/oak"
|
17 |
OUTPUT_REPO = "sson"
|
18 |
PUSH_TO_HUB = True
|
19 |
|
|
|
27 |
training_corpus,
|
28 |
vocab_size=VOCAB_SIZE,
|
29 |
min_frequency=2,
|
30 |
+
special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>", "<|user|>", "<|bot|>", "<|end|>"]
|
31 |
)
|
32 |
|
33 |
fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
|
|
|
78 |
"unk_token": "<unk>",
|
79 |
"pad_token": "<pad>",
|
80 |
"mask_token": "<mask>",
|
81 |
+
"additional_special_tokens": ["<|user|>", "<|bot|>", "<|end|>"]
|
82 |
}
|
83 |
tokenizer.add_special_tokens(special_tokens)
|
84 |
|
85 |
+
tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
|
86 |
+
tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")
|
87 |
|
88 |
+
chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}{{ eos_token }}"
|
89 |
+
tokenizer.chat_template = chat_template
|
90 |
|
91 |
def train_model(model, tokenizer, dataset, push):
|
92 |
args = TrainingArguments(
|
|
|
96 |
learning_rate=LEARNING_RATE,
|
97 |
optim="sgd"
|
98 |
)
|
99 |
+
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
|
100 |
trainer = trl.SFTTrainer(
|
101 |
model=model,
|
102 |
tokenizer=tokenizer,
|