Spaces:
Build error
Build error
| from typing import Any, Literal, Union | |
| from pathlib import Path | |
| import jsonlines | |
| from PIL import Image | |
| from torch import Tensor | |
| from torch.utils.data import Dataset | |
| import torchvision.transforms as transforms | |
| import numpy as np | |
| import os | |
| import json | |
| from src.utils import bbox_augmentation_resize | |
class PubTables(Dataset):
    """PubTables-1M-Structure dataset.

    Files are read from three locations under ``root_dir``:

    * ``<root_dir>/<split>/``: its directory listing defines the sample
      names (everything before the first ``".xml"`` in each entry).
    * ``<root_dir>/images/<name>.jpg``: the table image.
    * ``<root_dir>/words/<name>_words.json``: a JSON list of word records;
      the code reads the ``"bbox"`` (indices 0..3) and ``"text"`` keys.

    What ``__getitem__`` returns depends on ``label_type``:

    * ``"image"``: the (optionally transformed) image.
    * contains ``"bbox"``: ``{"filename", "image", "bbox"}`` where ``"bbox"``
      is one flat list of resized box values.
    * contains ``"cell"``: ``(crops, records)``, i.e. one image crop and one
      ``{"filename", "bbox_id", "cell"}`` record per selected word.

    Args:
        root_dir: Dataset root directory.
        label_type: Selects the kind of sample returned (see above).
        split: Name of the sub-directory whose listing defines the samples.
        transform: Optional callable applied to a PIL image. It is required
            for ``"bbox"`` samples because the resize target is read from
            the transformed image.
        cell_limit: For ``"cell"`` samples, only the first ``cell_limit``
            entries of the words file are considered.
    """

    def __init__(
        self,
        root_dir: Union[Path, str],
        label_type: Literal["image", "cell", "bbox"],
        split: Literal["train", "val", "test"],
        transform: Any = None,
        cell_limit: int = 100,
    ) -> None:
        super().__init__()
        self.root_dir = Path(root_dir)
        self.split = split
        self.label_type = label_type
        self.transform = transform
        self.cell_limit = cell_limit
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index maps to the same sample on every run and platform.
        entries = sorted(os.listdir(self.root_dir / self.split))
        self.image_list = [entry.split(".xml")[0] for entry in entries]

    def __len__(self):
        """Return the number of samples found in the split directory."""
        return len(self.image_list)

    def _read_words(self, name: str) -> Any:
        """Load ``words/<name>_words.json`` and return the parsed JSON."""
        path = os.path.join(self.root_dir, "words", name + "_words.json")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def __getitem__(self, index: int) -> Any:
        """Return the sample at ``index`` in the form chosen by ``label_type``.

        Raises:
            ValueError: If ``label_type`` is not supported, or if it selects
                bbox samples while no transform is configured.
        """
        name = self.image_list[index]

        # Resolve the mode with the same precedence as the original chain:
        # exact "image", then substring "bbox", then substring "cell".
        want_image = self.label_type == "image"
        want_bbox = not want_image and "bbox" in self.label_type
        want_cell = not want_image and not want_bbox and "cell" in self.label_type
        if not (want_image or want_bbox or want_cell):
            raise ValueError(f"Unsupported label_type: {self.label_type!r}")
        if want_bbox and not self.transform:
            # The resize target is taken from the transformed image, so there
            # is nothing to scale the boxes to without a transform.
            raise ValueError(
                "label_type 'bbox' requires a transform: the target size is "
                "read from the transformed image"
            )

        img = Image.open(os.path.join(self.root_dir, "images", name + ".jpg"))

        if want_image:
            if self.transform:
                img = self.transform(img)
            return img

        # PIL reports size as (width, height); capture it before any transform.
        img_size = img.size

        if want_bbox:
            img = self.transform(img)
            # Assumes the transform returns an object with ``.shape`` whose
            # last dimension is the target size (e.g. a tensor) - TODO confirm.
            tgt_size = img.shape[-1]

            width, height = img_size
            boxes = []
            for word in self._read_words(name):
                if "bbox" not in word:
                    continue
                box = word["bbox"]
                # Keep only boxes with positive width and height.
                if not (box[2] > box[0] and box[3] > box[1]):
                    continue
                # Clamp the coordinates into the image before rescaling.
                clamped = [
                    min(max(box[0], 0), width),
                    min(max(box[1], 0), height),
                    min(max(box[2], 0), width),
                    min(max(box[3], 0), height),
                ]
                # Values from every box are concatenated into one flat list.
                boxes.extend(bbox_augmentation_resize(clamped, img_size, tgt_size))
            return {"filename": name, "image": img, "bbox": boxes}

        # "cell" samples: one crop plus one text record per selected word.
        bboxes_texts = []
        for idx, word in enumerate(self._read_words(name)):
            # The limit applies to the position in the raw list, not to the
            # number of accepted cells, so fewer than cell_limit may be kept.
            if idx >= self.cell_limit:
                break
            if "bbox" not in word:
                continue
            box = word["bbox"]
            non_empty = box[0] < box[2] and box[1] < box[3]
            inside = (
                box[0] >= 0
                and box[1] >= 0
                and box[2] < img_size[0]
                and box[3] < img_size[1]
            )
            if non_empty and inside:
                bboxes_texts.append((box, word["text"]))

        img_bboxes = []
        for box, _ in bboxes_texts:
            crop = img.crop(box)
            # Without a transform, return the raw crop (mirrors "image" mode).
            if self.transform is not None:
                crop = self.transform(crop)
            img_bboxes.append(crop)

        text_bboxes = [
            {"filename": name, "bbox_id": cell_id, "cell": text}
            for cell_id, (_, text) in enumerate(bboxes_texts)
        ]
        return img_bboxes, text_bboxes