"""Segmentation dataset utilities: a paired image/mask Dataset and a batch collate function."""
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import datasets
import numpy as np
import PIL.Image
import torch
from torch.utils.data import Dataset
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for semantic segmentation.

    Expects the on-disk layout::

        root/images/<subset>/**/*.jpg        (RGB input images)
        root/annotations/<subset>/**/*.png   (grayscale label masks)

    Images and masks are paired positionally after sorting each file list,
    so corresponding files must sort into the same order in both trees.
    """

    def __init__(
        self,
        root: str,
        subset: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Index all image/mask files under *root* for the given *subset*.

        Args:
            root: Dataset root directory containing ``images/`` and ``annotations/``.
            subset: Split name, e.g. ``"train"`` or ``"val"``.
            transform: Optional callable applied to each PIL image.
            target_transform: Optional callable applied to each PIL mask.

        Raises:
            ValueError: If the number of images and masks differ (which would
                silently mispair samples under positional matching).
        """
        super().__init__()
        self.images_dir = Path(root) / "images" / subset
        self.masks_dir = Path(root) / "annotations" / subset
        self.transform = transform
        self.target_transform = target_transform
        # glob order is filesystem-dependent; sort so images[i] pairs with masks[i]
        self.images = sorted(self.images_dir.glob("**/*.jpg"))
        self.masks = sorted(self.masks_dir.glob("**/*.png"))
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images but {len(self.masks)} masks "
                f"under {self.images_dir} / {self.masks_dir}"
            )

    def __len__(self) -> int:
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load and return the (image, mask) pair at *idx*.

        The image is forced to RGB and the mask to single-channel ("L");
        the declared Tensor return type assumes the transforms convert the
        PIL images to tensors — without transforms, PIL images are returned.
        """
        image = PIL.Image.open(self.images[idx]).convert("RGB")
        mask = PIL.Image.open(self.masks[idx]).convert("L")
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)
        return image, mask
def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collate a list of (image, mask) pairs into batched tensors.

    Stacks all images along a new leading batch dimension, and likewise
    for the masks, returning a ``(image_batch, mask_batch)`` tuple.
    """
    image_batch = torch.stack([image for image, _ in items])
    mask_batch = torch.stack([mask for _, mask in items])
    return image_batch, mask_batch