nroggendorff committed on
Commit
3ce5976
·
verified ·
1 Parent(s): be2f96f

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -4
train.py CHANGED
@@ -23,8 +23,6 @@ GRADIENT_ACCUMULATION_STEPS = 8
23
  CLIPPING = 1.0
24
  PUSH_TO_HUB = True
25
 
26
- accelerator = Accelerator()
27
-
28
  def load_data():
29
  dataset = load_dataset(INPUT_DATASET, split="train")#.select(range(int(2e+4)))
30
  return dataset
@@ -96,7 +94,7 @@ def configure_tokenizer(tokenizer):
96
  chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}{{ eos_token }}"
97
  tokenizer.chat_template = chat_template
98
 
99
- def train_model(model, tokenizer, dataset, push):
100
  args = TrainingArguments(
101
  output_dir="model",
102
  num_train_epochs=EPOCHS,
@@ -148,7 +146,8 @@ def main(push_to_hub=True):
148
  tokenizer = create_tokenizer(training_corpus)
149
  configure_tokenizer(tokenizer)
150
  model = create_model(tokenizer)
151
- train_model(model, tokenizer, dataset, push_to_hub)
 
152
 
153
  if __name__ == "__main__":
154
  main(PUSH_TO_HUB)
 
23
  CLIPPING = 1.0
24
  PUSH_TO_HUB = True
25
 
 
 
26
  def load_data():
27
  dataset = load_dataset(INPUT_DATASET, split="train")#.select(range(int(2e+4)))
28
  return dataset
 
94
  chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}{{ eos_token }}"
95
  tokenizer.chat_template = chat_template
96
 
97
+ def train_model(accelerator, model, tokenizer, dataset, push):
98
  args = TrainingArguments(
99
  output_dir="model",
100
  num_train_epochs=EPOCHS,
 
146
  tokenizer = create_tokenizer(training_corpus)
147
  configure_tokenizer(tokenizer)
148
  model = create_model(tokenizer)
149
+ accelerator = Accelerator()
150
+ train_model(accelerator, model, tokenizer, dataset, push_to_hub)
151
 
152
  if __name__ == "__main__":
153
  main(PUSH_TO_HUB)