File size: 1,469 Bytes
b108d0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import datasets
import numpy as np
import PIL.Image
import torch
from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    """Paired image / segmentation-mask dataset.

    Expects the directory layout::

        root/images/<subset>/**/*.jpg       -- RGB input images
        root/annotations/<subset>/**/*.png  -- single-channel masks

    Image/mask pairs are matched by sorted path order, so the two trees
    must contain files in corresponding order.
    """

    def __init__(
        self,
        root: str,
        subset: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        Args:
            root: Dataset root directory.
            subset: Split name, e.g. "train" or "val".
            transform: Optional callable applied to each loaded PIL image.
            target_transform: Optional callable applied to each loaded PIL mask.

        Raises:
            ValueError: If the number of images and masks differ.
        """
        super().__init__()
        self.images_dir = Path(root) / "images" / subset
        self.masks_dir = Path(root) / "annotations" / subset
        self.transform = transform
        self.target_transform = target_transform

        # images_dir / masks_dir are already Path objects; sorted() accepts
        # the glob generator directly, no list() needed.
        self.images = sorted(self.images_dir.glob("**/*.jpg"))
        self.masks = sorted(self.masks_dir.glob("**/*.png"))

        # Fail fast on a mismatched tree instead of an IndexError (or a
        # silently wrong image/mask pairing) later in __getitem__.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images but {len(self.masks)} masks "
                f"under {self.images_dir} and {self.masks_dir}"
            )

    def __len__(self) -> int:
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load and return the (image, mask) pair at *idx*.

        The image is converted to RGB and the mask to single-channel ("L")
        before the optional transforms run. NOTE(review): with no transforms
        the return values are PIL images, not tensors, despite the
        annotation — callers are expected to pass tensor-producing
        transforms.
        """
        image = PIL.Image.open(self.images[idx]).convert("RGB")
        mask = PIL.Image.open(self.masks[idx]).convert("L")
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)
        return image, mask


def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collate a list of (image, mask) pairs into batched tensors.

    Stacks the images along a new leading batch dimension, and likewise
    for the masks, returning one (images, masks) tensor pair.
    """
    batched_images = torch.stack([image for image, _ in items])
    batched_masks = torch.stack([mask for _, mask in items])
    return batched_images, batched_masks