Spaces:
Build error
Build error
from typing import Any, Literal, Union | |
from pathlib import Path | |
import jsonlines | |
from PIL import Image | |
from torch import Tensor | |
from torch.utils.data import Dataset | |
import torchvision.transforms as transforms | |
import numpy as np | |
import os | |
import json | |
from src.utils import bbox_augmentation_resize | |
class PubTables(Dataset): | |
"""PubTables-1M-Structure""" | |
def __init__( | |
self, | |
root_dir: Union[Path, str], | |
label_type: Literal["image", "cell", "bbox"], | |
split: Literal["train", "val", "test"], | |
transform: transforms = None, | |
cell_limit: int = 100, | |
) -> None: | |
super().__init__() | |
self.root_dir = Path(root_dir) | |
self.split = split | |
self.label_type = label_type | |
self.transform = transform | |
self.cell_limit = cell_limit | |
tmp = os.listdir(self.root_dir / self.split) | |
self.image_list = [i.split(".xml")[0] for i in tmp] | |
def __len__(self): | |
return len(self.image_list) | |
def __getitem__(self, index: int) -> Any: | |
name = self.image_list[index] | |
img = Image.open(os.path.join(self.root_dir, "images", name + ".jpg")) | |
if self.label_type == "image": | |
if self.transform: | |
img = self.transform(img) | |
return img | |
elif "bbox" in self.label_type: | |
img_size = img.size | |
if self.transform: | |
img = self.transform(img) | |
tgt_size = img.shape[-1] | |
with open( | |
os.path.join(self.root_dir, "words", name + "_words.json"), "r" | |
) as f: | |
obj = json.load(f) | |
obj[:] = [ | |
v | |
for i in obj | |
if "bbox" in i.keys() | |
and all([i["bbox"][w + 2] > i["bbox"][w] for w in range(2)]) | |
for v in bbox_augmentation_resize( | |
[ | |
min(max(i["bbox"][0], 0), img_size[0]), | |
min(max(i["bbox"][1], 0), img_size[1]), | |
min(max(i["bbox"][2], 0), img_size[0]), | |
min(max(i["bbox"][3], 0), img_size[1]), | |
], | |
img_size, | |
tgt_size, | |
) | |
] | |
sample = {"filename": name, "image": img, "bbox": obj} | |
return sample | |
elif "cell" in self.label_type: | |
img_size = img.size | |
with open( | |
os.path.join(self.root_dir, "words", name + "_words.json"), "r" | |
) as f: | |
obj = json.load(f) | |
bboxes_texts = [ | |
(i["bbox"], i["text"]) | |
for idx, i in enumerate(obj) | |
if "bbox" in i | |
and i["bbox"][0] < i["bbox"][2] | |
and i["bbox"][1] < i["bbox"][3] | |
and i["bbox"][0] >= 0 | |
and i["bbox"][1] >= 0 | |
and i["bbox"][2] < img_size[0] | |
and i["bbox"][3] < img_size[1] | |
and idx < self.cell_limit | |
] | |
img_bboxes = [self.transform(img.crop(bbox[0])) for bbox in bboxes_texts] | |
text_bboxes = [ | |
{"filename": name, "bbox_id": i, "cell": j[1]} | |
for i, j in enumerate(bboxes_texts) | |
] | |
return img_bboxes, text_bboxes | |