Vishwas1 committed on
Commit
93a2c3f
·
verified ·
1 Parent(s): ef223be

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +17 -29
train_model.py CHANGED
@@ -12,10 +12,10 @@ from transformers import (
12
  DataCollatorForLanguageModeling,
13
  DataCollatorWithPadding,
14
  )
15
- from datasets import load_dataset, Dataset
16
  import torch
17
  import os
18
- from huggingface_hub import login, HfApi, HfFolder
19
  import logging
20
 
21
  from torch.optim import AdamW # Import PyTorch's AdamW
@@ -34,10 +34,9 @@ def setup_logging(log_file_path):
34
  f_handler.setLevel(logging.INFO)
35
 
36
  # Create formatters and add to handlers
37
- c_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
38
- f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
39
- c_handler.setFormatter(c_format)
40
- f_handler.setFormatter(f_format)
41
 
42
  # Add handlers to the logger
43
  logger.addHandler(c_handler)
@@ -66,30 +65,18 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
66
  """
67
  logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
68
  try:
69
- if task == "generation":
70
- # Check if dataset_name includes a configuration
71
- if '/' in dataset_name:
72
- dataset, config = dataset_name.split('/', 1)
73
- dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
74
- else:
75
- dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
76
- logging.info("Dataset loaded successfully for generation task.")
77
- def tokenize_function(examples):
78
- return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
79
- elif task == "classification":
80
- if '/' in dataset_name:
81
- dataset, config = dataset_name.split('/', 1)
82
- dataset = load_dataset(dataset, config, split='train')
83
- else:
84
- dataset = load_dataset(dataset_name, split='train')
85
- logging.info("Dataset loaded successfully for classification task.")
86
- # Assuming the dataset has 'text' and 'label' columns
87
- def tokenize_function(examples):
88
- return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
89
  else:
90
- raise ValueError("Unsupported task type")
91
 
92
- # Shuffle and select a subset
 
 
 
 
 
93
  tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
94
  logging.info("Dataset tokenization complete.")
95
  return tokenized_datasets
@@ -186,7 +173,7 @@ def main():
186
  logging.info("Setting pad_token to eos_token.")
187
  tokenizer.pad_token = tokenizer.eos_token
188
  logging.info(f"Tokenizer pad_token set to: {tokenizer.pad_token}")
189
- # Resize model's token embeddings after setting pad_token
190
  model = initialize_model(
191
  task=args.task,
192
  model_name=args.model_name,
@@ -315,3 +302,4 @@ def main():
315
 
316
  if __name__ == "__main__":
317
  main()
 
 
12
  DataCollatorForLanguageModeling,
13
  DataCollatorWithPadding,
14
  )
15
+ from datasets import load_dataset
16
  import torch
17
  import os
18
+ from huggingface_hub import login, HfApi
19
  import logging
20
 
21
  from torch.optim import AdamW # Import PyTorch's AdamW
 
34
  f_handler.setLevel(logging.INFO)
35
 
36
  # Create formatters and add to handlers
37
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
38
+ c_handler.setFormatter(formatter)
39
+ f_handler.setFormatter(formatter)
 
40
 
41
  # Add handlers to the logger
42
  logger.addHandler(c_handler)
 
65
  """
66
  logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
67
  try:
68
+ if '/' in dataset_name:
69
+ dataset, config = dataset_name.split('/', 1)
70
+ dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  else:
72
+ dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
73
 
74
+ logging.info("Dataset loaded successfully.")
75
+
76
+ def tokenize_function(examples):
77
+ return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
78
+
79
+ # Tokenize the dataset
80
  tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
81
  logging.info("Dataset tokenization complete.")
82
  return tokenized_datasets
 
173
  logging.info("Setting pad_token to eos_token.")
174
  tokenizer.pad_token = tokenizer.eos_token
175
  logging.info(f"Tokenizer pad_token set to: {tokenizer.pad_token}")
176
+ # Initialize model after setting pad_token
177
  model = initialize_model(
178
  task=args.task,
179
  model_name=args.model_name,
 
302
 
303
  if __name__ == "__main__":
304
  main()
305
+