Vishwas1 committed
Commit c9f3a0d · verified · 1 Parent(s): 7ead975

Update train_model.py

Files changed (1)
  1. train_model.py +31 -10
train_model.py CHANGED
@@ -67,8 +67,20 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
     logging.info("Dataset loaded successfully.")
 
     def tokenize_function(examples):
-        # Truncate and set max_length, but let DataCollator handle padding
-        return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
+        try:
+            # Tokenize with truncation, defer padding to DataCollator
+            tokens = tokenizer(
+                examples['text'],
+                truncation=True,
+                max_length=sequence_length,  # Set maximum length
+                padding=False,  # Padding will be handled by the DataCollatorWithPadding
+                return_tensors=None  # Let the DataCollator handle tensor creation
+            )
+            return tokens
+        except Exception as e:
+            logging.error(f"Error during tokenization: {e}")
+            logging.error(f"Example data: {examples}")
+            raise e
 
     # Tokenize the dataset using the modified tokenize_function
     tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
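
Note: the rewritten tokenize_function returns unpadded, variable-length token lists on purpose, so each batch can be padded only to its own longest sequence. A minimal standalone sketch of how DataCollatorWithPadding then pads dynamically (the bert-base-uncased checkpoint here is illustrative, not the model this script trains):

    from transformers import AutoTokenizer, DataCollatorWithPadding

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    features = tokenizer(
        ["a short text", "a somewhat longer example sentence"],
        truncation=True,
        max_length=128,
        padding=False,  # same setting as tokenize_function above
    )
    # Re-shape the batch encoding into per-example dicts, as datasets.map produces
    examples = [{k: v[i] for k, v in features.items()} for i in range(2)]

    collator = DataCollatorWithPadding(tokenizer=tokenizer)
    batch = collator(examples)
    print(batch["input_ids"].shape)  # padded to the longest sequence in this batch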
@@ -182,7 +194,6 @@ def main():
         logging.error(f"Error initializing tokenizer or model: {str(e)}")
         raise e
 
-    # Load and prepare dataset
     # Load and prepare dataset
     try:
         tokenized_datasets = load_and_prepare_dataset(
@@ -194,26 +205,38 @@ def main():
     except Exception as e:
         logging.error("Failed to load and prepare dataset.")
         raise e
-
+
     # Define data collator
     if args.task == "generation":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif args.task == "classification":
-        data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)  # Explicit padding
+        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)  # Dynamic padding during batch creation
     else:
         logging.error("Unsupported task type for data collator.")
         raise ValueError("Unsupported task type for data collator.")
-
+
+    # Define training arguments
+    training_args = TrainingArguments(
+        output_dir=f"./models/{args.model_name}",
+        num_train_epochs=3,
+        per_device_train_batch_size=8 if args.task == "generation" else 16,
+        save_steps=5000,
+        save_total_limit=2,
+        logging_steps=500,
+        learning_rate=5e-4 if args.task == "generation" else 5e-5,
+        remove_unused_columns=False,
+        push_to_hub=False
+    )
+
     # Initialize Trainer with the data collator
     trainer = Trainer(
         model=model,
         args=training_args,
         train_dataset=tokenized_datasets,
         data_collator=data_collator,
-        optimizers=(get_optimizer(model, training_args.learning_rate), None)  # None for scheduler
+        optimizers=(get_optimizer(model, training_args.learning_rate), None)
     )
 
-
     # Start training
     logging.info("Starting training...")
     try:
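
Note: in Trainer's optimizers argument the tuple is (optimizer, lr_scheduler); leaving the second element None lets Trainer build its default learning-rate schedule. get_optimizer is defined elsewhere in train_model.py and is not shown in this diff; a hypothetical stand-in, assuming plain AdamW:

    import torch

    def get_optimizer(model, learning_rate):
        # Stand-in only: the real helper may configure weight decay, param groups, etc.
        return torch.optim.AdamW(model.parameters(), lr=learning_rate)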
@@ -253,5 +276,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
-
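
Note: for the generation branch, DataCollatorForLanguageModeling(mlm=False) both pads the batch and builds causal-LM labels by copying input_ids, with padded positions set to -100 so the loss ignores them. A minimal sketch (the gpt2 checkpoint is illustrative):

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token by default

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    examples = [tokenizer("hello world"), tokenizer("a longer example input")]
    batch = collator(examples)
    print(batch["labels"])  # mirrors input_ids, with -100 at padded positions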