Vishwas1 committed on
Commit
2ac79ea
·
verified ·
1 Parent(s): 2de0e9b

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +14 -21
train_model.py CHANGED
@@ -14,18 +14,20 @@ from huggingface_hub import HfApi, HfFolder
14
  import logging
15
 
16
  def main():
17
- # Configure Logging
18
- logging.basicConfig(
19
- filename='training.log',
20
- filemode='a',
21
- format='%(asctime)s - %(levelname)s - %(message)s',
22
- level=logging.INFO
23
- )
 
 
24
 
25
  parser = argparse.ArgumentParser()
26
  parser.add_argument("--task", type=str, required=True, help="Task type: generation or classification")
27
  parser.add_argument("--model_name", type=str, required=True, help="Name of the model")
28
- parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset")
29
  parser.add_argument("--num_layers", type=int, default=12)
30
  parser.add_argument("--attention_heads", type=int, default=1)
31
  parser.add_argument("--hidden_size", type=int, default=64)
@@ -53,26 +55,17 @@ def main():
53
 
54
  # Load and prepare dataset
55
  if args.task == "generation":
56
- dataset = load_dataset('text', data_files={'train': args.dataset})
57
  def tokenize_function(examples):
58
  return tokenizer(examples['text'], truncation=True, max_length=args.sequence_length)
59
  elif args.task == "classification":
60
- # For classification, assume the dataset is a simple text file with "text\tlabel" per line
61
- with open(args.dataset, "r", encoding="utf-8") as f:
62
- lines = f.readlines()
63
- texts = []
64
- labels = []
65
- for line in lines:
66
- parts = line.strip().split("\t")
67
- if len(parts) == 2:
68
- texts.append(parts[0])
69
- labels.append(int(parts[1]))
70
- dataset = Dataset.from_dict({"text": texts, "label": labels})
71
  def tokenize_function(examples):
72
  return tokenizer(examples['text'], truncation=True, max_length=args.sequence_length)
73
  else:
74
  raise ValueError("Unsupported task type")
75
-
76
  tokenized_datasets = dataset.map(tokenize_function, batched=True)
77
 
78
  if args.task == "generation":
 
14
  import logging
15
 
16
  def main():
17
+ # ... existing code ...
18
+ if args.task == "generation":
19
+ dataset = load_dataset(args.dataset_name, split='train') # Load dataset by name
20
+ elif args.task == "classification":
21
+ dataset = load_dataset(args.dataset_name, split='train') # Adjust if necessary
22
+ else:
23
+ raise ValueError("Unsupported task type")
24
+ # ... existing code ...
25
+
26
 
27
  parser = argparse.ArgumentParser()
28
  parser.add_argument("--task", type=str, required=True, help="Task type: generation or classification")
29
  parser.add_argument("--model_name", type=str, required=True, help="Name of the model")
30
+ parser.add_argument("--dataset_name", type=str, required=True, help="Name of the Hugging Face dataset")
31
  parser.add_argument("--num_layers", type=int, default=12)
32
  parser.add_argument("--attention_heads", type=int, default=1)
33
  parser.add_argument("--hidden_size", type=int, default=64)
 
55
 
56
  # Load and prepare dataset
57
  if args.task == "generation":
58
+ dataset = load_dataset(args.dataset_name, split='train')
59
  def tokenize_function(examples):
60
  return tokenizer(examples['text'], truncation=True, max_length=args.sequence_length)
61
  elif args.task == "classification":
62
+ dataset = load_dataset(args.dataset_name, split='train')
63
+ # Assuming the dataset has 'text' and 'label' columns
 
 
 
 
 
 
 
 
 
64
  def tokenize_function(examples):
65
  return tokenizer(examples['text'], truncation=True, max_length=args.sequence_length)
66
  else:
67
  raise ValueError("Unsupported task type")
68
+
69
  tokenized_datasets = dataset.map(tokenize_function, batched=True)
70
 
71
  if args.task == "generation":