Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
@@ -6,6 +6,7 @@ import trl
|
|
6 |
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast, AdamW, get_linear_schedule_with_warmup
|
7 |
from datasets import load_dataset
|
8 |
from tokenizers import ByteLevelBPETokenizer
|
|
|
9 |
|
10 |
MAX_SEQ_LENGTH = 512
|
11 |
BATCH_SIZE = 32
|
@@ -22,6 +23,8 @@ GRADIENT_ACCUMULATION_STEPS = 8
|
|
22 |
CLIPPING = 1.0
|
23 |
PUSH_TO_HUB = True
|
24 |
|
|
|
|
|
25 |
def load_data():
|
26 |
dataset = load_dataset(INPUT_DATASET, split="train")#.select(range(int(2e+4)))
|
27 |
return dataset
|
@@ -124,6 +127,11 @@ def train_model(model, tokenizer, dataset, push):
|
|
124 |
max_seq_length=MAX_SEQ_LENGTH,
|
125 |
optimizers=(optimizer, scheduler)
|
126 |
)
|
|
|
|
|
|
|
|
|
|
|
127 |
trainer.train()
|
128 |
|
129 |
trained_model = trainer.model
|
|
|
6 |
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast, AdamW, get_linear_schedule_with_warmup
|
7 |
from datasets import load_dataset
|
8 |
from tokenizers import ByteLevelBPETokenizer
|
9 |
+
from accelerate import Accelerator
|
10 |
|
11 |
MAX_SEQ_LENGTH = 512
|
12 |
BATCH_SIZE = 32
|
|
|
23 |
CLIPPING = 1.0
|
24 |
PUSH_TO_HUB = True
|
25 |
|
26 |
+
accelerator = Accelerator()
|
27 |
+
|
28 |
def load_data():
|
29 |
dataset = load_dataset(INPUT_DATASET, split="train")#.select(range(int(2e+4)))
|
30 |
return dataset
|
|
|
127 |
max_seq_length=MAX_SEQ_LENGTH,
|
128 |
optimizers=(optimizer, scheduler)
|
129 |
)
|
130 |
+
|
131 |
+
model, optimizer = accelerator.prepare(model, optimizer)
|
132 |
+
trainer.model = model
|
133 |
+
trainer.optimizer = optimizer
|
134 |
+
trainer = accelerator.prepare(trainer)
|
135 |
trainer.train()
|
136 |
|
137 |
trained_model = trainer.model
|