Spaces:
Running
Running
File size: 2,946 Bytes
aea73e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# -*- coding: utf-8 -*-
# Patched WSI Dataset used for inference, mainly for calculating embeddings
#
# @ Fabian Hörst, [email protected]
# Institute for Artifical Intelligence in Medicine,
# University Medicine Essen
from typing import Callable, Tuple, List
import torch
from torch.utils.data import Dataset
from datamodel.wsi_datamodel import WSI
class PatchedWSIInference(Dataset):
"""Inference Dataset, used for calculating embeddings of *one* WSI. Wrapped around a WSI object
Args:
wsi_object (
filelist (list[str]): List with filenames as entries. Filenames should match the key pattern in wsi_objects dictionary
transform (Callable): Inference Transformations
"""
def __init__(
self,
wsi_object: WSI,
transform: Callable,
) -> None:
# set all configurations
assert isinstance(wsi_object, WSI), "Must be a WSI-object"
assert (
wsi_object.patched_slide_path is not None
), "Please provide a WSI that already has been patched into slices"
self.transform = transform
self.wsi_object = wsi_object
def __getitem__(
self, idx: int
) -> Tuple[torch.Tensor, list[list[str, str]], list[str], int, str]:
"""Returns one WSI with patches, coords, filenames, labels and wsi name for given idx
Args:
idx (int): Index of WSI to retrieve
Returns:
Tuple[torch.Tensor, list[list[str,str]], list[str], int, str]:
* torch.Tensor: Tensor with shape [num_patches, 3, height, width], includes all patches for one WSI
* list[list[str,str]]: List with coordinates as list entries, e.g., [['1', '1'], ['2', '1'], ..., ['row', 'col']]
* list[str]: List with patch filenames
* int: Patient label as integer
* str: String with WSI name
"""
patch_name = self.wsi_object.patches_list[idx]
patch, metadata = self.wsi_object.process_patch_image(
patch_name=patch_name, transform=self.transform
)
return patch, metadata
def __len__(self) -> int:
"""Return len of dataset
Returns:
int: Len of dataset
"""
return int(self.wsi_object.get_number_patches())
@staticmethod
def collate_batch(batch: List[Tuple]) -> Tuple[torch.Tensor, list[dict]]:
"""Create a custom batch
Needed to unpack List of tuples with dictionaries and array
Args:
batch (List[Tuple]): Input batch consisting of a list of tuples (patch, patch-metadata)
Returns:
Tuple[torch.Tensor, list[dict]]:
New batch: patches with shape [batch_size, 3, patch_size, patch_size], list of metadata dicts
"""
patches, metadata = zip(*batch)
patches = torch.stack(patches)
metadata = list(metadata)
return patches, metadata
|