Vishwas1 committed on
Commit
ef223be
·
verified ·
1 Parent(s): 55f1be4

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +4 -4
train_model.py CHANGED
@@ -70,18 +70,18 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
70
  # Check if dataset_name includes a configuration
71
  if '/' in dataset_name:
72
  dataset, config = dataset_name.split('/', 1)
73
- dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train', use_auth_token=True)
74
  else:
75
- dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train', use_auth_token=True)
76
  logging.info("Dataset loaded successfully for generation task.")
77
  def tokenize_function(examples):
78
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
79
  elif task == "classification":
80
  if '/' in dataset_name:
81
  dataset, config = dataset_name.split('/', 1)
82
- dataset = load_dataset(dataset, config, split='train', use_auth_token=True)
83
  else:
84
- dataset = load_dataset(dataset_name, split='train', use_auth_token=True)
85
  logging.info("Dataset loaded successfully for classification task.")
86
  # Assuming the dataset has 'text' and 'label' columns
87
  def tokenize_function(examples):
 
70
  # Check if dataset_name includes a configuration
71
  if '/' in dataset_name:
72
  dataset, config = dataset_name.split('/', 1)
73
+ dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
74
  else:
75
+ dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
76
  logging.info("Dataset loaded successfully for generation task.")
77
  def tokenize_function(examples):
78
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
79
  elif task == "classification":
80
  if '/' in dataset_name:
81
  dataset, config = dataset_name.split('/', 1)
82
+ dataset = load_dataset(dataset, config, split='train')
83
  else:
84
+ dataset = load_dataset(dataset_name, split='train')
85
  logging.info("Dataset loaded successfully for classification task.")
86
  # Assuming the dataset has 'text' and 'label' columns
87
  def tokenize_function(examples):