Vishwas1 commited on
Commit
7ead975
·
verified ·
1 Parent(s): a2a02fa

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +7 -18
train_model.py CHANGED
@@ -182,6 +182,7 @@ def main():
182
  logging.error(f"Error initializing tokenizer or model: {str(e)}")
183
  raise e
184
 
 
185
  # Load and prepare dataset
186
  try:
187
  tokenized_datasets = load_and_prepare_dataset(
@@ -193,38 +194,26 @@ def main():
193
  except Exception as e:
194
  logging.error("Failed to load and prepare dataset.")
195
  raise e
196
-
197
  # Define data collator
198
  if args.task == "generation":
199
  data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
200
  elif args.task == "classification":
201
- data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
202
  else:
203
  logging.error("Unsupported task type for data collator.")
204
  raise ValueError("Unsupported task type for data collator.")
205
-
206
- # Define training arguments
207
- training_args = TrainingArguments(
208
- output_dir=f"./models/{args.model_name}",
209
- num_train_epochs=3,
210
- per_device_train_batch_size=8 if args.task == "generation" else 16,
211
- save_steps=5000,
212
- save_total_limit=2,
213
- logging_steps=500,
214
- learning_rate=5e-4 if args.task == "generation" else 5e-5,
215
- remove_unused_columns=False,
216
- push_to_hub=False
217
- )
218
-
219
- # Initialize Trainer with PyTorch's AdamW optimizer
220
  trainer = Trainer(
221
  model=model,
222
  args=training_args,
223
  train_dataset=tokenized_datasets,
224
  data_collator=data_collator,
225
- optimizers=(get_optimizer(model, training_args.learning_rate), None)
226
  )
227
 
 
228
  # Start training
229
  logging.info("Starting training...")
230
  try:
 
182
  logging.error(f"Error initializing tokenizer or model: {str(e)}")
183
  raise e
184
 
185
+ # Load and prepare dataset
186
  # Load and prepare dataset
187
  try:
188
  tokenized_datasets = load_and_prepare_dataset(
 
194
  except Exception as e:
195
  logging.error("Failed to load and prepare dataset.")
196
  raise e
197
+
198
  # Define data collator
199
  if args.task == "generation":
200
  data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
201
  elif args.task == "classification":
202
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True) # Explicit padding
203
  else:
204
  logging.error("Unsupported task type for data collator.")
205
  raise ValueError("Unsupported task type for data collator.")
206
+
207
+ # Initialize Trainer with the data collator
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  trainer = Trainer(
209
  model=model,
210
  args=training_args,
211
  train_dataset=tokenized_datasets,
212
  data_collator=data_collator,
213
+ optimizers=(get_optimizer(model, training_args.learning_rate), None) # None for scheduler
214
  )
215
 
216
+
217
  # Start training
218
  logging.info("Starting training...")
219
  try: