File size: 1,347 Bytes
96b6673 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from torch.utils.data import Dataset
import json
def default_get(key):
    """Return a getter that indexes whatever object it is given with *key*.

    ``default_get(key)(data)`` is equivalent to ``data[key]``; a missing key
    raises the same error direct indexing would (e.g. ``KeyError`` for a
    dict).
    """
    def _getter(data):
        return data[key]
    return _getter
class PromptDataset(Dataset):
    """In-memory dataset built from an iterable of records.

    Each stored item is a new dict holding:

    * every name in ``keys`` that the record contains, mapped to the record's
      value for it (names absent from a record are silently skipped), then
    * one entry per keyword in ``projections``, whose value is that callable
      applied to the whole record. A projection named like one of ``keys``
      overrides the copied value.

    Args:
        data_dir: Iterable of records (despite the name, not a directory).
            Presumably dicts — each record must support ``in`` on
            ``.keys()`` and ``[]`` indexing.
        *keys: Field names to copy verbatim from each record.
        **projections: Output field name -> callable taking the record.
    """

    def __init__(self, data_dir, *keys, **projections) -> None:
        self.data = []
        for record in data_dir:
            item = {}
            # Copy requested fields first, skipping those this record lacks.
            for key in keys:
                if key in record.keys():
                    item[key] = record[key]
            # Projections are applied afterwards so they win on name clashes.
            for name, project in projections.items():
                item[name] = project(record)
            self.data.append(item)

    def __getitem__(self, index) -> dict:
        """Return the processed record stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return how many records were loaded."""
        return len(self.data)
class FileDataset(PromptDataset):
    """A ``PromptDataset`` whose records are read from a JSON file.

    The file is parsed with ``json.load`` as UTF-8 and its top-level value is
    iterated as the sequence of records (presumably a JSON array of objects).
    When no ``keys`` are given, the field names of the first record are used,
    so every field of that record is copied; projections are applied on top.

    Args:
        data_dir: Path of the JSON file to read (despite the name, a file,
            not a directory).
        *keys: Field names to copy verbatim from each record.
        **projections: Output field name -> callable taking the record.
    """

    def __init__(self, data_dir, *keys, **projections) -> None:
        with open(data_dir, 'r', encoding='utf-8') as file:
            records = json.load(file)
        if not keys:
            # Default to the first record's fields. An empty file has no first
            # record, so fall back to no keys instead of raising IndexError.
            keys = records[0].keys() if records else ()
        # Reuse the parent's record-processing logic rather than duplicating
        # it; __getitem__ and __len__ are inherited unchanged.
        super().__init__(records, *keys, **projections)