Spaces: Sleeping
Update train_model.py
Browse files — train_model.py (+7 −18)
train_model.py
CHANGED
@@ -182,6 +182,7 @@ def main():
|
|
182 |
logging.error(f"Error initializing tokenizer or model: {str(e)}")
|
183 |
raise e
|
184 |
|
|
|
185 |
# Load and prepare dataset
|
186 |
try:
|
187 |
tokenized_datasets = load_and_prepare_dataset(
|
@@ -193,38 +194,26 @@ def main():
|
|
193 |
except Exception as e:
|
194 |
logging.error("Failed to load and prepare dataset.")
|
195 |
raise e
|
196 |
-
|
197 |
# Define data collator
|
198 |
if args.task == "generation":
|
199 |
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
200 |
elif args.task == "classification":
|
201 |
-
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
202 |
else:
|
203 |
logging.error("Unsupported task type for data collator.")
|
204 |
raise ValueError("Unsupported task type for data collator.")
|
205 |
-
|
206 |
-
#
|
207 |
-
training_args = TrainingArguments(
|
208 |
-
output_dir=f"./models/{args.model_name}",
|
209 |
-
num_train_epochs=3,
|
210 |
-
per_device_train_batch_size=8 if args.task == "generation" else 16,
|
211 |
-
save_steps=5000,
|
212 |
-
save_total_limit=2,
|
213 |
-
logging_steps=500,
|
214 |
-
learning_rate=5e-4 if args.task == "generation" else 5e-5,
|
215 |
-
remove_unused_columns=False,
|
216 |
-
push_to_hub=False
|
217 |
-
)
|
218 |
-
|
219 |
-
# Initialize Trainer with PyTorch's AdamW optimizer
|
220 |
trainer = Trainer(
|
221 |
model=model,
|
222 |
args=training_args,
|
223 |
train_dataset=tokenized_datasets,
|
224 |
data_collator=data_collator,
|
225 |
-
optimizers=(get_optimizer(model, training_args.learning_rate), None)
|
226 |
)
|
227 |
|
|
|
228 |
# Start training
|
229 |
logging.info("Starting training...")
|
230 |
try:
|
|
|
182 |
logging.error(f"Error initializing tokenizer or model: {str(e)}")
|
183 |
raise e
|
184 |
|
185 |
+
# Load and prepare dataset
|
186 |
# Load and prepare dataset
|
187 |
try:
|
188 |
tokenized_datasets = load_and_prepare_dataset(
|
|
|
194 |
except Exception as e:
|
195 |
logging.error("Failed to load and prepare dataset.")
|
196 |
raise e
|
197 |
+
|
198 |
# Define data collator
|
199 |
if args.task == "generation":
|
200 |
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
201 |
elif args.task == "classification":
|
202 |
+
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True) # Explicit padding
|
203 |
else:
|
204 |
logging.error("Unsupported task type for data collator.")
|
205 |
raise ValueError("Unsupported task type for data collator.")
|
206 |
+
|
207 |
+
# Initialize Trainer with the data collator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
trainer = Trainer(
|
209 |
model=model,
|
210 |
args=training_args,
|
211 |
train_dataset=tokenized_datasets,
|
212 |
data_collator=data_collator,
|
213 |
+
optimizers=(get_optimizer(model, training_args.learning_rate), None) # None for scheduler
|
214 |
)
|
215 |
|
216 |
+
|
217 |
# Start training
|
218 |
logging.info("Starting training...")
|
219 |
try:
|