nroggendorff commited on
Commit
1f2defa
·
verified ·
1 Parent(s): faf4f27

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +13 -9
train.py CHANGED
@@ -15,6 +15,7 @@ FACTOR = 128
15
  VOCAB_SIZE = 3200
16
  INPUT_DATASET = "nroggendorff/elephant"
17
  OUTPUT_REPO = "smallama"
 
18
 
19
  def load_data():
20
  dataset = load_dataset(INPUT_DATASET, split="train")
@@ -55,8 +56,8 @@ def create_model(tokenizer):
55
  vocab_size=tokenizer.vocab_size,
56
  hidden_size=FACTOR,
57
  intermediate_size=FACTOR * 4,
58
- num_hidden_layers=FACTOR // 32,
59
- num_attention_heads=FACTOR // 64,
60
  max_position_embeddings=MAX_SEQ_LENGTH,
61
  rms_norm_eps=1e-6,
62
  initializer_range=0.02,
@@ -87,7 +88,7 @@ def configure_tokenizer(tokenizer):
87
  chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}{{ eos_token }}"
88
  tokenizer.chat_template = chat_template
89
 
90
- def train_model(model, tokenizer, dataset):
91
  args = TrainingArguments(
92
  output_dir="model",
93
  num_train_epochs=EPOCHS,
@@ -109,18 +110,21 @@ def train_model(model, tokenizer, dataset):
109
  trained_model = trainer.model
110
  trained_tokenizer = trainer.tokenizer
111
 
112
- repo_id = OUTPUT_REPO
113
- trained_model.push_to_hub(repo_id)
114
- trained_tokenizer.push_to_hub(repo_id)
 
 
 
115
 
116
- def main():
117
  dataset = load_data()
118
  training_corpus = get_training_corpus(dataset)
119
  tokenizer = create_tokenizer(training_corpus)
120
  configure_tokenizer(tokenizer)
121
  model = create_model(tokenizer)
122
- train_model(model, tokenizer, dataset)
123
 
124
  if __name__ == "__main__":
125
- main()
126
  raise RuntimeError("The script is finished.")
 
15
  VOCAB_SIZE = 3200
16
  INPUT_DATASET = "nroggendorff/elephant"
17
  OUTPUT_REPO = "smallama"
18
+ PUSH_TO_HUB = True
19
 
20
  def load_data():
21
  dataset = load_dataset(INPUT_DATASET, split="train")
 
56
  vocab_size=tokenizer.vocab_size,
57
  hidden_size=FACTOR,
58
  intermediate_size=FACTOR * 4,
59
+ num_hidden_layers=max(1, FACTOR // 32),
60
+ num_attention_heads=max(1, FACTOR // 64),
61
  max_position_embeddings=MAX_SEQ_LENGTH,
62
  rms_norm_eps=1e-6,
63
  initializer_range=0.02,
 
88
  chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}{{ eos_token }}"
89
  tokenizer.chat_template = chat_template
90
 
91
+ def train_model(model, tokenizer, dataset, push):
92
  args = TrainingArguments(
93
  output_dir="model",
94
  num_train_epochs=EPOCHS,
 
110
  trained_model = trainer.model
111
  trained_tokenizer = trainer.tokenizer
112
 
113
+ if push:
114
+ repo_id = OUTPUT_REPO
115
+ trained_model.push_to_hub(repo_id)
116
+ trained_tokenizer.push_to_hub(repo_id)
117
+ else:
118
+ trained_tokenizer.save_pretrained("tokenizer")
119
 
120
+ def main(push_to_hub=True):
121
  dataset = load_data()
122
  training_corpus = get_training_corpus(dataset)
123
  tokenizer = create_tokenizer(training_corpus)
124
  configure_tokenizer(tokenizer)
125
  model = create_model(tokenizer)
126
+ train_model(model, tokenizer, dataset, push_to_hub)
127
 
128
  if __name__ == "__main__":
129
+ main(PUSH_TO_HUB)
130
  raise RuntimeError("The script is finished.")