# Spaces: Running on Zero
from pathlib import Path | |
from typing import Optional | |
import numpy as np | |
import PIL.Image as Image | |
from torch.utils.data import Dataset | |
from vhap.util.log import get_logger | |
logger = get_logger(__name__) | |
class ImageFolderDataset(Dataset):
    """Dataset serving every ``*.jpg`` image of a folder as a numpy array.

    Optionally, background images are loaded once at construction time, keyed
    by camera id, and the matching background is attached to each item.
    """

    def __init__(
        self,
        image_folder: Path,
        background_folder: Optional[Path]=None,
        background_fname2camId=lambda x: x,
        image_fname2camId=lambda x: x,
    ):
        """
        Args:
            image_folder: Path to dataset with the following directory layout
                <image_folder>/
                |---xx.jpg
                |---...
                A plain string is accepted and converted to ``Path``.
            background_folder: Optional folder holding ``*.jpg`` background
                images. It is joined onto ``image_folder``. When ``None``, no
                backgrounds are loaded and items carry no 'background' key.
            background_fname2camId: Callable mapping a background file name
                to the camera id used as the lookup key.
            image_fname2camId: Callable mapping an image file name to the
                camera id used to pick its background.
        """
        super().__init__()
        # Accept str as well as Path so that `.glob` below always works.
        image_folder = Path(image_folder)

        self.image_fname2camId = image_fname2camId
        self.background_folder = background_folder
        # Historical misspelling, kept as an alias for code that reads it.
        self.background_foler = background_folder

        logger.info(f"Initializing dataset from folder {image_folder}")
        self.image_paths = sorted(image_folder.glob('*.jpg'))

        # Always defined so the attribute exists even without backgrounds.
        self.backgrounds = {}
        if background_folder is not None:
            background_paths = sorted((image_folder / background_folder).glob('*.jpg'))
            for background_path in background_paths:
                cam_id = background_fname2camId(background_path.name)
                self.backgrounds[cam_id] = self._load_image(background_path)

    @staticmethod
    def _load_image(path):
        """Read the image at ``path`` into a numpy array and close the file."""
        # Image.open is lazy; np.array forces the decode while the file is
        # still open, and the context manager then releases the handle.
        with Image.open(path) as img:
            return np.array(img)

    def __len__(self):
        """Return the number of images found in the image folder."""
        return len(self.image_paths)

    def __getitem__(self, i):
        """Load the i-th image (in sorted path order).

        Returns:
            dict with 'rgb' (image as numpy array), 'image_path' (str) and,
            when a background folder was given, 'background'.

        Raises:
            KeyError: if a background folder was given but no background was
                loaded for this image's camera id.
        """
        image_path = self.image_paths[i]
        cam_id = self.image_fname2camId(image_path.name)
        item = {
            "rgb": self._load_image(image_path),
            'image_path': str(image_path),
        }

        if self.background_folder is not None:
            try:
                item['background'] = self.backgrounds[cam_id]
            except KeyError:
                raise KeyError(
                    f"No background image for camera id {cam_id!r} (image: {image_path})"
                ) from None
        return item
if __name__ == "__main__":
    # Manual smoke test: build the dataset from a placeholder folder, inspect
    # one sample, then iterate everything once through a DataLoader.
    from tqdm import tqdm
    from torch.utils.data import DataLoader

    # The constructor calls `.glob` on image_folder, so pass a Path; it has no
    # `img_to_tensor` parameter, so that keyword is not passed.
    dataset = ImageFolderDataset(
        image_folder=Path('./xx'),
    )

    print(len(dataset))

    # Only index into the dataset when the folder actually contains images.
    if len(dataset) > 0:
        sample = dataset[0]
        print(sample.keys())
        print(sample["rgb"].shape)

    # batch_size=None disables automatic batching: items are yielded one by one.
    dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
    for item in tqdm(dataloader):
        pass