Vishwas1 committed on
Commit
7ffd02f
·
verified ·
1 Parent(s): bb7dbb8

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +25 -13
train_model.py CHANGED
@@ -1,5 +1,5 @@
1
  # train_model.py (Training Script)
2
- from huggingface_hub import login
3
  import argparse
4
  from transformers import (
5
  GPT2Config,
@@ -15,7 +15,7 @@ from transformers import (
15
  from datasets import load_dataset, Dataset
16
  import torch
17
  import os
18
- from huggingface_hub import HfApi, HfFolder
19
  import logging
20
 
21
  def setup_logging(log_file_path):
@@ -49,7 +49,7 @@ def parse_arguments():
49
  parser.add_argument("--task", type=str, required=True, choices=["generation", "classification"],
50
  help="Task type: 'generation' or 'classification'")
51
  parser.add_argument("--model_name", type=str, required=True, help="Name of the model")
52
- parser.add_argument("--dataset_name", type=str, required=True, help="Name of the Hugging Face dataset (e.g., 'username/dataset')")
53
  parser.add_argument("--num_layers", type=int, default=12, help="Number of hidden layers")
54
  parser.add_argument("--attention_heads", type=int, default=1, help="Number of attention heads")
55
  parser.add_argument("--hidden_size", type=int, default=64, help="Hidden size of the model")
@@ -65,14 +65,21 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
65
  logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
66
  try:
67
  if task == "generation":
68
- train_dataset = load_dataset(dataset_name,split='train',use_auth_token=True)
69
- dataset = train_dataset['train'].shuffle(seed=42).select(range(500))
 
 
 
 
70
  logging.info("Dataset loaded successfully for generation task.")
71
  def tokenize_function(examples):
72
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
73
  elif task == "classification":
74
- train_dataset = load_dataset(dataset_name,split='train',use_auth_token=True)
75
- dataset = train_dataset['train'].shuffle(seed=42).select(range(500))
 
 
 
76
  logging.info("Dataset loaded successfully for classification task.")
77
  # Assuming the dataset has 'text' and 'label' columns
78
  def tokenize_function(examples):
@@ -80,7 +87,8 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
80
  else:
81
  raise ValueError("Unsupported task type")
82
 
83
- tokenized_datasets = dataset.map(tokenize_function, batched=True)
 
84
  logging.info("Dataset tokenization complete.")
85
  return tokenized_datasets
86
  except Exception as e:
@@ -139,18 +147,22 @@ def main():
139
 
140
  # Initialize Hugging Face API
141
  api = HfApi()
142
- hf_token = os.getenv('HF_API_TOKEN')
 
 
143
  if not hf_token:
144
- logging.error("HF_API_TOKEN is not set. Please set it as an environment variable.")
145
- raise ValueError("HF_API_TOKEN is not set.")
146
-
147
- # Initialize tokenizer
148
  try:
149
  login(token=hf_token)
150
  logging.info("Successfully logged in to Hugging Face Hub.")
151
  except Exception as e:
152
  logging.error(f"Failed to log in to Hugging Face Hub: {str(e)}")
153
  raise e
 
 
154
  try:
155
  logging.info("Initializing tokenizer...")
156
  if args.task == "generation":
 
1
  # train_model.py (Training Script)
2
+
3
  import argparse
4
  from transformers import (
5
  GPT2Config,
 
15
  from datasets import load_dataset, Dataset
16
  import torch
17
  import os
18
+ from huggingface_hub import login, HfApi, HfFolder
19
  import logging
20
 
21
  def setup_logging(log_file_path):
 
49
  parser.add_argument("--task", type=str, required=True, choices=["generation", "classification"],
50
  help="Task type: 'generation' or 'classification'")
51
  parser.add_argument("--model_name", type=str, required=True, help="Name of the model")
52
+ parser.add_argument("--dataset_name", type=str, required=True, help="Name of the Hugging Face dataset (e.g., 'wikitext/wikitext-2-raw-v1')")
53
  parser.add_argument("--num_layers", type=int, default=12, help="Number of hidden layers")
54
  parser.add_argument("--attention_heads", type=int, default=1, help="Number of attention heads")
55
  parser.add_argument("--hidden_size", type=int, default=64, help="Hidden size of the model")
 
65
  logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
66
  try:
67
  if task == "generation":
68
+ # Check if dataset_name includes config
69
+ if '/' in dataset_name:
70
+ dataset, config = dataset_name.split('/', 1)
71
+ dataset = load_dataset(dataset, config, split='train[:1%]', use_auth_token=True)
72
+ else:
73
+ dataset = load_dataset(dataset_name, split='train[:1%]', use_auth_token=True)
74
  logging.info("Dataset loaded successfully for generation task.")
75
  def tokenize_function(examples):
76
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
77
  elif task == "classification":
78
+ if '/' in dataset_name:
79
+ dataset, config = dataset_name.split('/', 1)
80
+ dataset = load_dataset(dataset, config, split='train[:1%]', use_auth_token=True)
81
+ else:
82
+ dataset = load_dataset(dataset_name, split='train[:1%]', use_auth_token=True)
83
  logging.info("Dataset loaded successfully for classification task.")
84
  # Assuming the dataset has 'text' and 'label' columns
85
  def tokenize_function(examples):
 
87
  else:
88
  raise ValueError("Unsupported task type")
89
 
90
+ # Shuffle and select a subset
91
+ tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
92
  logging.info("Dataset tokenization complete.")
93
  return tokenized_datasets
94
  except Exception as e:
 
147
 
148
  # Initialize Hugging Face API
149
  api = HfApi()
150
+
151
+ # Retrieve the Hugging Face API token from environment variables
152
+ hf_token = os.getenv("HF_API_TOKEN")
153
  if not hf_token:
154
+ logging.error("HF_API_TOKEN environment variable not set.")
155
+ raise ValueError("HF_API_TOKEN environment variable not set.")
156
+
157
+ # Perform login using the API token
158
  try:
159
  login(token=hf_token)
160
  logging.info("Successfully logged in to Hugging Face Hub.")
161
  except Exception as e:
162
  logging.error(f"Failed to log in to Hugging Face Hub: {str(e)}")
163
  raise e
164
+
165
+ # Initialize tokenizer
166
  try:
167
  logging.info("Initializing tokenizer...")
168
  if args.task == "generation":