Spaces:
Running
on
T4
Running
on
T4
File size: 2,520 Bytes
0c3992e a00d62c 0c3992e a00d62c 0c3992e a00d62c 0c3992e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import ast
import copy
import os.path as osp

import pandas as pd
import torch
class STaRKDataset:
    '''
    Question-answering dataset backed by a CSV file and per-split index files.

    The CSV (``stark_qa.csv`` or ``stark_qa_human_generated_eval.csv`` inside
    ``query_dir``) is read with pandas and must provide the columns ``id``,
    ``query`` and ``answer_ids``. Each ``{split}.index`` file inside
    ``split_dir`` lists one query id per line.

    Attributes set by ``__init__``:
        data: the loaded ``pandas.DataFrame``.
        indices: sorted list of the query ids found in ``data['id']``.
        split_indices: result of ``get_idx_split()`` at construction time.
    '''

    # Split names that have a matching `{split}.index` file.
    _SPLITS = ('train', 'val', 'test')

    def __init__(self, query_dir, split_dir, human_generated_eval=False):
        '''
        Load the QA table and compute the split indices.

        Args:
            query_dir: directory containing the QA CSV file.
            split_dir: directory containing the `{split}.index` files.
            human_generated_eval: if True, load the human-generated
                evaluation CSV instead of the default one.
        '''
        self.query_dir = query_dir
        self.split_dir = split_dir
        self.human_generated_eval = human_generated_eval

        csv_name = ('stark_qa_human_generated_eval.csv'
                    if human_generated_eval else 'stark_qa.csv')
        self.qa_csv_path = osp.join(query_dir, csv_name)

        print('Loading QA dataset from', self.qa_csv_path)
        self.data = pd.read_csv(self.qa_csv_path)

        self.indices = sorted(self.data['id'])
        self.split_indices = self.get_idx_split()

    def __len__(self):
        '''Return the number of query ids currently held in `indices`.'''
        return len(self.indices)

    def __getitem__(self, idx):
        '''
        Return `(query, q_id, answer_ids, meta_info)` for position `idx`.

        `idx` is a position in `self.indices`, not a query id. `meta_info`
        is always None.

        Raises:
            ValueError / SyntaxError: if the `answer_ids` cell is not a
                plain Python literal.
        '''
        q_id = self.indices[idx]
        meta_info = None
        row = self._row_by_qid(q_id)
        query = row['query']
        # SECURITY: this cell used to go through eval(), which would execute
        # arbitrary code stored in the CSV. literal_eval only accepts Python
        # literals (e.g. "[1, 2, 3]") and rejects everything else.
        answer_ids = ast.literal_eval(row['answer_ids'])
        return query, q_id, answer_ids, meta_info

    def get_idx_split(self, test_ratio=1.0):
        '''
        Return the indices of train/val/test split in a dictionary.

        The values are `torch.LongTensor`s of positions in `self.indices`
        (the index files themselves store query ids). In human-generated
        evaluation mode a single key `'human_generated_eval'` maps to a
        tensor built directly from `self.indices`.

        Args:
            test_ratio: when below 1.0, keep only that leading fraction of
                the test split.

        Raises:
            ValueError: if an index file lists a query id that is not in
                `self.indices`.
        '''
        if self.human_generated_eval:
            return {'human_generated_eval': torch.LongTensor(self.indices)}

        # Build the id -> position map once instead of calling list.index()
        # for every query id. setdefault keeps the first occurrence, which
        # is what list.index() returned.
        position_of = {}
        for position, q_id in enumerate(self.indices):
            position_of.setdefault(q_id, position)

        split_idx = {}
        for split in self._SPLITS:
            positions = []
            for q_id in self._read_split_ids(split):
                if q_id not in position_of:
                    raise ValueError(
                        f"query id {q_id} from '{split}.index' "
                        f"is not in the dataset"
                    )
                positions.append(position_of[q_id])
            split_idx[split] = torch.LongTensor(positions)

        if test_ratio < 1.0:
            n_keep = int(len(split_idx['test']) * test_ratio)
            split_idx['test'] = split_idx['test'][:n_keep]
        return split_idx

    def get_query_by_qid(self, q_id):
        '''
        Return the query by query id.
        '''
        return self._row_by_qid(q_id)['query']

    def get_subset(self, split):
        '''
        Return a subset of the dataset.

        The result is a deep copy of this dataset whose `indices` are the
        query ids listed in `{split}.index`, in file order. Its
        `split_indices` attribute is copied unchanged, so it still refers
        to positions in the full dataset.
        '''
        assert split in self._SPLITS
        split_ids = self._read_split_ids(split)
        subset = copy.deepcopy(self)
        subset.indices = split_ids
        return subset

    def _read_split_ids(self, split):
        '''Read `{split}.index` from `split_dir` as a list of int query ids.'''
        indices_file = osp.join(self.split_dir, f'{split}.index')
        # The context manager closes the handle; blank lines are skipped so
        # an empty file gives an empty split instead of failing on int('').
        with open(indices_file, 'r') as handle:
            return [int(line) for line in handle if line.strip()]

    def _row_by_qid(self, q_id):
        '''Return the first row of `data` whose `id` equals `q_id`.'''
        return self.data[self.data['id'] == q_id].iloc[0]
|