File size: 2,946 Bytes
aea73e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# -*- coding: utf-8 -*-
# Patched WSI Dataset used for inference, mainly for calculating embeddings
#
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

from typing import Callable, Tuple, List

import torch
from torch.utils.data import Dataset
from datamodel.wsi_datamodel import WSI


class PatchedWSIInference(Dataset):
    """Inference dataset over the patches of *one* WSI, mainly used for calculating embeddings.

    Thin wrapper around a WSI object: each dataset item is a single patch of
    that WSI together with the metadata returned for it.

    Args:
        wsi_object (WSI): WSI whose patches are served. It must already be patched,
            i.e. its ``patched_slide_path`` must not be None.
        transform (Callable): Inference transformations, forwarded to
            ``wsi_object.process_patch_image`` for every requested patch.
    """

    def __init__(
        self,
        wsi_object: "WSI",
        transform: Callable,
    ) -> None:
        super().__init__()

        # Validate the configuration before storing it.
        # NOTE: these checks are plain asserts and are therefore skipped when
        # Python runs with -O; they are kept so the raised exception type
        # (AssertionError) stays the same for existing callers.
        assert isinstance(wsi_object, WSI), "Must be a WSI-object"
        assert (
            wsi_object.patched_slide_path is not None
        ), "Please provide a WSI that already has been patched into slices"

        self.transform = transform
        self.wsi_object = wsi_object

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, dict]:
        """Return one patch of the WSI and its metadata for the given index

        The patch name is looked up in ``wsi_object.patches_list`` and loaded via
        ``wsi_object.process_patch_image`` with the configured transform.

        Args:
            idx (int): Index of the patch inside ``wsi_object.patches_list``

        Returns:
            Tuple[torch.Tensor, dict]:

            * torch.Tensor: Transformed patch (presumably with shape [3, height, width],
              confirm against ``WSI.process_patch_image``)
            * dict: Metadata of the patch as returned by ``WSI.process_patch_image``
        """
        patch_name = self.wsi_object.patches_list[idx]

        patch, metadata = self.wsi_object.process_patch_image(
            patch_name=patch_name, transform=self.transform
        )

        return patch, metadata

    def __len__(self) -> int:
        """Return len of dataset

        Returns:
            int: Number of patches reported by the wrapped WSI object
        """
        return int(self.wsi_object.get_number_patches())

    @staticmethod
    def collate_batch(batch: List[Tuple]) -> Tuple[torch.Tensor, list[dict]]:
        """Create a custom batch

        Needed to unpack a list of (patch, metadata) tuples into one stacked
        tensor and one list of metadata entries.

        Args:
            batch (List[Tuple]): Input batch consisting of a list of tuples (patch, patch-metadata)

        Returns:
            Tuple[torch.Tensor, list[dict]]:
                New batch: patches stacked along a new first dimension
                (e.g. [batch_size, 3, patch_size, patch_size]), list of metadata dicts

        Raises:
            ValueError: If the batch is empty
        """
        # An empty batch would otherwise fail with an opaque tuple-unpacking
        # ValueError; raise the same exception type with a clear message.
        if not batch:
            raise ValueError("Cannot collate an empty batch")

        patches, metadata = zip(*batch)
        return torch.stack(patches), list(metadata)