Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
@@ -116,6 +116,8 @@ def train_model(model, tokenizer, dataset, push):
|
|
116 |
num_warmup_steps=args.warmup_steps,
|
117 |
num_training_steps=len(dataset) * args.num_train_epochs // args.gradient_accumulation_steps
|
118 |
)
|
|
|
|
|
119 |
|
120 |
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
|
121 |
trainer = trl.SFTTrainer(
|
@@ -127,11 +129,7 @@ def train_model(model, tokenizer, dataset, push):
|
|
127 |
max_seq_length=MAX_SEQ_LENGTH,
|
128 |
optimizers=(optimizer, scheduler)
|
129 |
)
|
130 |
-
|
131 |
-
model, optimizer = accelerator.prepare(model, optimizer)
|
132 |
-
trainer.model = model
|
133 |
-
trainer.optimizer = optimizer
|
134 |
-
trainer = accelerator.prepare(trainer)
|
135 |
trainer.train()
|
136 |
|
137 |
trained_model = trainer.model
|
|
|
116 |
num_warmup_steps=args.warmup_steps,
|
117 |
num_training_steps=len(dataset) * args.num_train_epochs // args.gradient_accumulation_steps
|
118 |
)
|
119 |
+
|
120 |
+
model, optimizer = accelerator.prepare(model, optimizer)
|
121 |
|
122 |
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
|
123 |
trainer = trl.SFTTrainer(
|
|
|
129 |
max_seq_length=MAX_SEQ_LENGTH,
|
130 |
optimizers=(optimizer, scheduler)
|
131 |
)
|
132 |
+
|
|
|
|
|
|
|
|
|
133 |
trainer.train()
|
134 |
|
135 |
trained_model = trainer.model
|