nroggendorff commited on
Commit
d2ce25e
·
verified ·
1 Parent(s): f681719

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -5
train.py CHANGED
@@ -116,6 +116,8 @@ def train_model(model, tokenizer, dataset, push):
116
  num_warmup_steps=args.warmup_steps,
117
  num_training_steps=len(dataset) * args.num_train_epochs // args.gradient_accumulation_steps
118
  )
 
 
119
 
120
  dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
121
  trainer = trl.SFTTrainer(
@@ -127,11 +129,7 @@ def train_model(model, tokenizer, dataset, push):
127
  max_seq_length=MAX_SEQ_LENGTH,
128
  optimizers=(optimizer, scheduler)
129
  )
130
-
131
- model, optimizer = accelerator.prepare(model, optimizer)
132
- trainer.model = model
133
- trainer.optimizer = optimizer
134
- trainer = accelerator.prepare(trainer)
135
  trainer.train()
136
 
137
  trained_model = trainer.model
 
116
  num_warmup_steps=args.warmup_steps,
117
  num_training_steps=len(dataset) * args.num_train_epochs // args.gradient_accumulation_steps
118
  )
119
+
120
+ model, optimizer = accelerator.prepare(model, optimizer)
121
 
122
  dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
123
  trainer = trl.SFTTrainer(
 
129
  max_seq_length=MAX_SEQ_LENGTH,
130
  optimizers=(optimizer, scheduler)
131
  )
132
+
 
 
 
 
133
  trainer.train()
134
 
135
  trained_model = trainer.model