Vishwas1 committed
Commit 4a9e5f8 · verified · 1 Parent(s): c9f3a0d

Update train_model.py

Files changed (1)
  train_model.py +14 -8
train_model.py CHANGED
@@ -16,7 +16,7 @@ import torch
 import os
 from huggingface_hub import login, HfApi
 import logging
-from torch.optim import AdamW # Import PyTorch's AdamW
+from torch.optim import AdamW
 
 def setup_logging(log_file_path):
     """
@@ -66,23 +66,28 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
     dataset = load_dataset(dataset_name, split='train')
     logging.info("Dataset loaded successfully.")
 
+    # Log some examples to check dataset structure
+    logging.info(f"Example data from the dataset: {dataset[:5]}")
+
     def tokenize_function(examples):
         try:
-            # Tokenize with truncation, defer padding to DataCollator
+            # Tokenize with truncation and padding
             tokens = tokenizer(
                 examples['text'],
                 truncation=True,
-                max_length=sequence_length, # Set maximum length
-                padding=False, # Padding will be handled by the DataCollatorWithPadding
-                return_tensors=None # Let the DataCollator handle tensor creation
+                max_length=sequence_length,
+                padding='max_length', # Force padding to max length for debugging
+                return_tensors=None # Let the collator handle tensor conversion
             )
+            # Log the tokens for debugging
+            logging.info(f"Tokenized example: {tokens}")
             return tokens
         except Exception as e:
             logging.error(f"Error during tokenization: {e}")
-            logging.error(f"Example data: {examples}")
+            logging.error(f"Problematic example: {examples}")
             raise e
 
-    # Tokenize the dataset using the modified tokenize_function
+    # Tokenize the dataset
     tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
     logging.info("Dataset tokenization complete.")
     return tokenized_datasets
@@ -210,7 +215,7 @@ def main():
     if args.task == "generation":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif args.task == "classification":
-        data_collator = DataCollatorWithPadding(tokenizer=tokenizer) # Dynamic padding during batch creation
+        data_collator = DataCollatorWithPadding(tokenizer=tokenizer) # Handle padding dynamically during batching
     else:
         logging.error("Unsupported task type for data collator.")
         raise ValueError("Unsupported task type for data collator.")
@@ -276,3 +281,4 @@ def main():
 
 if __name__ == "__main__":
     main()
+
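The main behavioral change in this commit is the move from deferred padding (tokenize with padding=False and let DataCollatorWithPadding pad each batch) to padding='max_length' at tokenization time, plus extra logging for debugging. Below is a minimal sketch of that trade-off; it is not part of the commit, and the checkpoint name and example sentences are placeholders rather than values from train_model.py.

from transformers import AutoTokenizer, DataCollatorWithPadding

# Hypothetical tokenizer and toy inputs, only for illustration
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
texts = ["a short sentence", "a noticeably longer example sentence for padding"]

# What the commit switches to: pad every example to max_length at tokenization time
fixed = tokenizer(texts, truncation=True, max_length=32, padding="max_length")
print([len(ids) for ids in fixed["input_ids"]])  # [32, 32] regardless of batch contents

# What the previous revision relied on: tokenize without padding and let
# DataCollatorWithPadding pad each batch to its longest member at batching time
dynamic = tokenizer(texts, truncation=True, max_length=32, padding=False)
collator = DataCollatorWithPadding(tokenizer=tokenizer)
batch = collator([{"input_ids": ids} for ids in dynamic["input_ids"]])
print(batch["input_ids"].shape)  # (2, length of the longest example in this batch)

Padding everything to max_length costs extra compute on short sequences, but it makes every tensor the same shape, which is convenient while debugging shape or collation issues, as this commit does with its added logging.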