Vishwas1 committed on
Commit
791abc9
·
verified ·
1 Parent(s): 3266823

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +4 -4
train_model.py CHANGED
@@ -68,18 +68,18 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
68
  # Check if dataset_name includes config
69
  if '/' in dataset_name:
70
  dataset, config = dataset_name.split('/', 1)
71
- dataset = load_dataset(dataset, config, split='train[:1%]', use_auth_token=True)
72
  else:
73
- dataset = load_dataset(dataset_name, split='train[:1%]', use_auth_token=True)
74
  logging.info("Dataset loaded successfully for generation task.")
75
  def tokenize_function(examples):
76
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
77
  elif task == "classification":
78
  if '/' in dataset_name:
79
  dataset, config = dataset_name.split('/', 1)
80
- dataset = load_dataset(dataset, config, split='train[:1%]', use_auth_token=True)
81
  else:
82
- dataset = load_dataset(dataset_name, split='train[:1%]', use_auth_token=True)
83
  logging.info("Dataset loaded successfully for classification task.")
84
  # Assuming the dataset has 'text' and 'label' columns
85
  def tokenize_function(examples):
 
68
  # Check if dataset_name includes config
69
  if '/' in dataset_name:
70
  dataset, config = dataset_name.split('/', 1)
71
+ dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train[:1%]', use_auth_token=True)
72
  else:
73
+ dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train[:1%]', use_auth_token=True)
74
  logging.info("Dataset loaded successfully for generation task.")
75
  def tokenize_function(examples):
76
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
77
  elif task == "classification":
78
  if '/' in dataset_name:
79
  dataset, config = dataset_name.split('/', 1)
80
+ dataset = load_dataset("stanfordnlp/imdb", split='train[:1%]', use_auth_token=True)
81
  else:
82
+ dataset = load_dataset("stanfordnlp/imdb", split='train[:1%]', use_auth_token=True)
83
  logging.info("Dataset loaded successfully for classification task.")
84
  # Assuming the dataset has 'text' and 'label' columns
85
  def tokenize_function(examples):