Update train.py

train.py
CHANGED
@@ -7,14 +7,14 @@ from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments
 from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer
 
-MAX_SEQ_LENGTH =
+MAX_SEQ_LENGTH = 512
 BATCH_SIZE = 512
 EPOCHS = 50
 LEARNING_RATE = 1e-5
 FACTOR = 2
 VOCAB_SIZE = 3200
-INPUT_DATASET = "nroggendorff/
-OUTPUT_REPO = "
+INPUT_DATASET = "nroggendorff/godson"
+OUTPUT_REPO = "sson"
 PUSH_TO_HUB = True
 
 def load_data():
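Note: the old values of MAX_SEQ_LENGTH, INPUT_DATASET, and OUTPUT_REPO are truncated in this diff view; the new revision sets them to 512, "nroggendorff/godson", and "sson". The body of load_data() is not part of this commit. A minimal sketch of what such a loader typically looks like with these constants, assuming the standard datasets API and a plain "train" split (both assumptions, not confirmed by the diff):

from datasets import load_dataset

INPUT_DATASET = "nroggendorff/godson"  # mirrors the constant above

def load_data():
    # Fetch the configured dataset from the Hub; the split name is an assumption.
    return load_dataset(INPUT_DATASET, split="train")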
@@ -27,7 +27,7 @@ def create_tokenizer(training_corpus):
         training_corpus,
         vocab_size=VOCAB_SIZE,
         min_frequency=2,
-        special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>", "<|user|>", "<|bot|>", "<|end|>"]
+        special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"]#, "<|user|>", "<|bot|>", "<|end|>"]
     )
 
     fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
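This hunk closes the special_tokens list after "<mask>" and leaves the three chat-role tokens commented out, so the retrained BPE vocabulary no longer reserves IDs for "<|user|>", "<|bot|>", and "<|end|>". For context, a minimal self-contained sketch of a create_tokenizer() along these lines, assuming tokenizers' train_from_iterator is the training call (the diff shows only its argument list):

from tokenizers import ByteLevelBPETokenizer
from transformers import PreTrainedTokenizerFast

VOCAB_SIZE = 3200  # mirrors the constant above

def create_tokenizer(training_corpus):
    tokenizer = ByteLevelBPETokenizer()
    # Train a byte-level BPE vocabulary from an iterator of raw text batches.
    tokenizer.train_from_iterator(
        training_corpus,
        vocab_size=VOCAB_SIZE,
        min_frequency=2,
        special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
    )
    # Wrap the underlying Rust tokenizer for use with transformers/trl.
    return PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)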
@@ -78,15 +78,15 @@ def configure_tokenizer(tokenizer):
         "unk_token": "<unk>",
         "pad_token": "<pad>",
         "mask_token": "<mask>",
-        "additional_special_tokens": ["<|user|>", "<|bot|>", "<|end|>"]
+        #"additional_special_tokens": ["<|user|>", "<|bot|>", "<|end|>"]
     }
     tokenizer.add_special_tokens(special_tokens)
 
-    tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
-    tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")
+    #tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
+    #tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")
 
-    chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}{{ eos_token }}"
-    tokenizer.chat_template = chat_template
+    #chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}{{ eos_token }}"
+    #tokenizer.chat_template = chat_template
 
 def train_model(model, tokenizer, dataset, push):
     args = TrainingArguments(
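This hunk disables the chat template along with the role tokens it emits. For reference, the now-commented template wraps each turn in <|user|> or <|bot|> plus <|end|>, and raises unless roles strictly alternate starting with user. A short usage sketch, assuming tokenizer is the script's tokenizer as configured before this commit (template still assigned) and that bos/eos are "<s>"/"</s>"; the example messages are invented:

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
# Render the conversation as plain text rather than token IDs.
text = tokenizer.apply_chat_template(messages, tokenize=False)
# text == "<s><|user|>\nHi<|end|>\n<|bot|>\nHello!<|end|>\n</s>"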
@@ -96,7 +96,7 @@ def train_model(model, tokenizer, dataset, push):
         learning_rate=LEARNING_RATE,
         optim="sgd"
     )
-    dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
+    #dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
     trainer = trl.SFTTrainer(
         model=model,
         tokenizer=tokenizer,