nroggendorff committed on
Commit
c2fde61
·
verified ·
1 Parent(s): ef1c7ad

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +10 -10
train.py CHANGED
@@ -7,14 +7,14 @@ from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingA
7
  from datasets import load_dataset
8
  from tokenizers import ByteLevelBPETokenizer
9
 
10
- MAX_SEQ_LENGTH = 128
11
  BATCH_SIZE = 512
12
  EPOCHS = 50
13
  LEARNING_RATE = 1e-5
14
  FACTOR = 2
15
  VOCAB_SIZE = 3200
16
- INPUT_DATASET = "nroggendorff/elephant"
17
- OUTPUT_REPO = "smallama"
18
  PUSH_TO_HUB = True
19
 
20
  def load_data():
@@ -27,7 +27,7 @@ def create_tokenizer(training_corpus):
27
  training_corpus,
28
  vocab_size=VOCAB_SIZE,
29
  min_frequency=2,
30
- special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>", "<|user|>", "<|bot|>", "<|end|>"]
31
  )
32
 
33
  fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
@@ -78,15 +78,15 @@ def configure_tokenizer(tokenizer):
78
  "unk_token": "<unk>",
79
  "pad_token": "<pad>",
80
  "mask_token": "<mask>",
81
- "additional_special_tokens": ["<|user|>", "<|bot|>", "<|end|>"]
82
  }
83
  tokenizer.add_special_tokens(special_tokens)
84
 
85
- tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
86
- tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")
87
 
88
- chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}{{ eos_token }}"
89
- tokenizer.chat_template = chat_template
90
 
91
  def train_model(model, tokenizer, dataset, push):
92
  args = TrainingArguments(
@@ -96,7 +96,7 @@ def train_model(model, tokenizer, dataset, push):
96
  learning_rate=LEARNING_RATE,
97
  optim="sgd"
98
  )
99
- dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
100
  trainer = trl.SFTTrainer(
101
  model=model,
102
  tokenizer=tokenizer,
 
7
  from datasets import load_dataset
8
  from tokenizers import ByteLevelBPETokenizer
9
 
10
+ MAX_SEQ_LENGTH = 512
11
  BATCH_SIZE = 512
12
  EPOCHS = 50
13
  LEARNING_RATE = 1e-5
14
  FACTOR = 2
15
  VOCAB_SIZE = 3200
16
+ INPUT_DATASET = "nroggendorff/godson"
17
+ OUTPUT_REPO = "sson"
18
  PUSH_TO_HUB = True
19
 
20
  def load_data():
 
27
  training_corpus,
28
  vocab_size=VOCAB_SIZE,
29
  min_frequency=2,
30
+ special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"]#, "<|user|>", "<|bot|>", "<|end|>"]
31
  )
32
 
33
  fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
 
78
  "unk_token": "<unk>",
79
  "pad_token": "<pad>",
80
  "mask_token": "<mask>",
81
+ #"additional_special_tokens": ["<|user|>", "<|bot|>", "<|end|>"]
82
  }
83
  tokenizer.add_special_tokens(special_tokens)
84
 
85
+ #tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
86
+ #tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")
87
 
88
+ #chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}{{ eos_token }}"
89
+ #tokenizer.chat_template = chat_template
90
 
91
  def train_model(model, tokenizer, dataset, push):
92
  args = TrainingArguments(
 
96
  learning_rate=LEARNING_RATE,
97
  optim="sgd"
98
  )
99
+ #dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
100
  trainer = trl.SFTTrainer(
101
  model=model,
102
  tokenizer=tokenizer,