gourisankar85 commited on
Commit
b9219fb
·
verified ·
1 Parent(s): 34d3a67

Upload load_dataset.py

Browse files
Files changed (1) hide show
  1. data/load_dataset.py +25 -23
data/load_dataset.py CHANGED
@@ -1,23 +1,25 @@
1
- import os
2
- import logging
3
- from datasets import load_dataset
4
- import pickle # For saving the dataset locally
5
-
6
- def load_data(data_set_name, local_path="local_datasets"):
7
- os.makedirs(local_path, exist_ok=True)
8
- dataset_file = os.path.join(local_path, f"{data_set_name}_test.pkl")
9
-
10
- if os.path.exists(dataset_file):
11
- logging.info("Loading dataset from local storage")
12
- with open(dataset_file, "rb") as f:
13
- dataset = pickle.load(f)
14
- else:
15
- logging.info("Loading dataset from Hugging Face")
16
- dataset = load_dataset("rungalileo/ragbench", data_set_name, split="test")
17
- logging.info(f"Saving {data_set_name} dataset locally")
18
- with open(dataset_file, "wb") as f:
19
- pickle.dump(dataset, f)
20
-
21
- logging.info("Dataset loaded successfully")
22
- logging.info(f"Number of documents found: {dataset.num_rows}")
23
- return dataset
 
 
 
1
+ import os
2
+ import logging
3
+ import pickle
4
+ from datasets import load_dataset
5
+ from config import ConfigConstants # For saving the dataset locally
6
+
7
+ def load_data(data_set_name, local_path=ConfigConstants.DATA_SET_PATH):
8
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
9
+ dataset_file = os.path.join(local_path, f"{data_set_name}_test.pkl")
10
+
11
+ if os.path.exists(dataset_file):
12
+ logging.info("Loading dataset {data_set_name} from local storage")
13
+ with open(dataset_file, "rb") as f:
14
+ dataset = pickle.load(f)
15
+ else:
16
+ logging.info("Loading dataset from Hugging Face")
17
+ dataset = load_dataset("rungalileo/ragbench", data_set_name, split="test")
18
+ logging.info(f"Saving {data_set_name} dataset locally")
19
+ with open(dataset_file, "wb") as f:
20
+ pickle.dump(dataset, f)
21
+
22
+ logging.info("Dataset loaded successfully")
23
+ logging.info(f"Number of documents found: {dataset.num_rows}")
24
+ return dataset
25
+