import os

from datasets import DatasetDict, concatenate_datasets, load_dataset

USE_LOCAL_DATASET = os.environ.get("USE_LOCAL_DATASET", "1") == "1"


def add_metadata_column(dataset, column_name, value):
    """Add a constant-valued column to a dataset, e.g. to tag each example with its source."""

    def add_source(example):
        example[column_name] = value
        return example

    return dataset.map(add_source)
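# Illustrative usage (a sketch; the variable names are hypothetical, not part of this module):
#   tagged = add_metadata_column(ds, "source", "docvqa_train")
# Tagging each set before concatenation keeps examples traceable to their origin.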
def _load_and_split(ds_paths: list[str]) -> DatasetDict:
    """Load the given training sets, concatenate and shuffle them, and hold out a 500-example test split."""
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        ds = load_dataset(base_path + path, split="train")
        if "arxivqa" in path:
            # ArXivQA is much larger than the other sets: subsample 10k examples.
            ds = ds.shuffle(seed=42).select(range(10000))
        ds_tot.append(ds)
    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # Split into train and test: the first 500 shuffled examples form the test set.
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    return DatasetDict({"train": dataset, "test": dataset_eval})


def load_train_set() -> DatasetDict:
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
    ]
    return _load_and_split(ds_paths)
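# Sketch of intended usage (the consumer shown here is an assumption, not part of this file):
#   ds_dict = load_train_set()
#   train_ds, eval_ds = ds_dict["train"], ds_dict["test"]
# Because every shuffle above is seeded, the train/test carve-out is deterministic across runs.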
def load_train_set_with_tabfquad() -> DatasetDict:
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "tabfquad_train_subsampled",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
    ]
    return _load_and_split(ds_paths)
def load_train_set_with_docmatix() -> DatasetDict:
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "tabfquad_train_subsampled",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
        "Docmatix_filtered_train",
    ]
    return _load_and_split(ds_paths)
def load_docvqa_dataset() -> DatasetDict:
    base_path = "./data_dir/DocVQA" if USE_LOCAL_DATASET else "lmms-lab/DocVQA"
    dataset_doc = load_dataset(base_path, "DocVQA", split="validation")
    dataset_doc_eval = load_dataset(base_path, "DocVQA", split="test")
    dataset_info = load_dataset(base_path, "InfographicVQA", split="validation")
    dataset_info_eval = load_dataset(base_path, "InfographicVQA", split="test")

    # Concatenate the DocVQA and InfographicVQA splits.
    dataset = concatenate_datasets([dataset_doc, dataset_info])
    dataset_eval = concatenate_datasets([dataset_doc_eval, dataset_info_eval])
    # Subsample 200 examples for evaluation.
    dataset_eval = dataset_eval.shuffle(seed=42).select(range(200))
    # Rename "question" to "query" to match the retrieval training schema.
    dataset = dataset.rename_column("question", "query")
    dataset_eval = dataset_eval.rename_column("question", "query")
    # Derive image_filename from ucsf_document_id when present, falling back to image_url.
    dataset = dataset.map(
        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
    )
    dataset_eval = dataset_eval.map(
        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
    )
    return DatasetDict({"train": dataset, "test": dataset_eval})
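# The two map calls above apply the same transform; a shared helper is a possible
# refactor (hypothetical name, shown only as a sketch):
#   def _with_image_filename(x):
#       return {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
#   dataset = dataset.map(_with_image_filename)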
class TestSetFactory:
    """Callable that loads the test split of the dataset at a given path."""

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path

    def __call__(self, *args, **kwargs):
        dataset = load_dataset(self.dataset_path, split="test")
        return dataset
if __name__ == "__main__":
    ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
    print(ds)
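    # Uncomment to sanity-check the full training mixture as well (note: this
    # downloads several datasets, so it is left disabled by default):
    # ds_dict = load_train_set()
    # print(ds_dict)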