Vishwas1 commited on
Commit
958029a
·
verified ·
1 Parent(s): 27f2ab5

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +4 -2
train_model.py CHANGED
@@ -65,12 +65,14 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
65
  logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
66
  try:
67
  if task == "generation":
68
- dataset = load_dataset(dataset_name, split='train',use_auth_token=True)
 
69
  logging.info("Dataset loaded successfully for generation task.")
70
  def tokenize_function(examples):
71
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
72
  elif task == "classification":
73
- dataset = load_dataset(dataset_name, split='train')
 
74
  logging.info("Dataset loaded successfully for classification task.")
75
  # Assuming the dataset has 'text' and 'label' columns
76
  def tokenize_function(examples):
 
65
  logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
66
  try:
67
  if task == "generation":
68
+ train_dataset = load_dataset(dataset_name,use_auth_token=True)
69
+ dataset = train_dataset['train'].shuffle(seed=42).select(range(500))
70
  logging.info("Dataset loaded successfully for generation task.")
71
  def tokenize_function(examples):
72
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
73
  elif task == "classification":
74
+ train_dataset = load_dataset(dataset_name,use_auth_token=True)
75
+ dataset = train_dataset['train'].shuffle(seed=42).select(range(500))
76
  logging.info("Dataset loaded successfully for classification task.")
77
  # Assuming the dataset has 'text' and 'label' columns
78
  def tokenize_function(examples):