Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# Patched WSI Dataset used for inference, mainly for calculating embeddings | |
# | |
# @ Fabian Hörst, [email protected] | |
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen | |
from typing import Callable, Tuple, List | |
import torch | |
from torch.utils.data import Dataset | |
from datamodel.wsi_datamodel import WSI | |
class PatchedWSIInference(Dataset):
    """Inference dataset over the patches of *one* WSI, e.g. for calculating embeddings.

    Thin wrapper around a WSI object: every dataset item is one entry of
    ``wsi_object.patches_list``, loaded through ``wsi_object.process_patch_image``.

    Args:
        wsi_object (WSI): WSI to iterate over. Its ``patched_slide_path`` must
            not be None, i.e. the slide must already have been patched.
        transform (Callable): Inference transformations, forwarded unchanged to
            ``wsi_object.process_patch_image`` for every patch.

    Raises:
        AssertionError: If ``wsi_object`` is not a ``WSI`` instance or has no
            ``patched_slide_path``.
    """

    def __init__(
        self,
        wsi_object: WSI,
        transform: Callable,
    ) -> None:
        # Validate up front so a misconfigured dataset fails at construction
        # time rather than on the first patch access inside a DataLoader worker.
        # NOTE: these are plain asserts and are therefore skipped under ``python -O``.
        assert isinstance(wsi_object, WSI), "Must be a WSI-object"
        assert (
            wsi_object.patched_slide_path is not None
        ), "Please provide a WSI that already has been patched into slices"

        self.transform = transform
        self.wsi_object = wsi_object

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, dict]:
        """Return one patch of the WSI together with its metadata.

        Args:
            idx (int): Index into ``wsi_object.patches_list``.

        Returns:
            Tuple[torch.Tensor, dict]: ``(patch, metadata)`` exactly as returned
                by ``wsi_object.process_patch_image`` for the selected patch.
                The element types are annotated to match what ``collate_batch``
                consumes (stackable patch tensors and metadata dicts).
        """
        patch_name = self.wsi_object.patches_list[idx]
        patch, metadata = self.wsi_object.process_patch_image(
            patch_name=patch_name, transform=self.transform
        )
        return patch, metadata

    def __len__(self) -> int:
        """Return the length of the dataset.

        Returns:
            int: Value of ``wsi_object.get_number_patches()``, coerced to ``int``.
        """
        return int(self.wsi_object.get_number_patches())
def collate_batch(batch: List[Tuple]) -> Tuple[torch.Tensor, list[dict]]:
    """Collate a list of (patch, metadata) samples into one batch.

    A custom collate function is needed because each sample pairs a tensor
    with a metadata dictionary; the tensors are stacked, the dictionaries are
    only gathered into a list.

    Args:
        batch (List[Tuple]): Samples to combine, each a tuple (patch, patch-metadata)

    Returns:
        Tuple[torch.Tensor, list[dict]]:
            * torch.Tensor: All patches stacked along a new leading batch dimension
            * list[dict]: Metadata entries, in the same order as the patches
    """
    # Transpose [(patch, meta), ...] into (patch, ...) and (meta, ...).
    patch_tensors, patch_metadata = zip(*batch)
    return torch.stack(patch_tensors), list(patch_metadata)