"""Torch dataset that serves preprocessed patches read from a slide via OpenSlide."""

import os

from openslide import OpenSlide
from scripts.preprocessor import MacenkoNormalizer, preprocessor
from torch.utils.data import Dataset


class WSIPatchDataset(Dataset):
    """Map-style dataset yielding one transformed square patch per coordinate.

    Each item is read from the slide at ``wsi_path`` with
    ``OpenSlide.read_region``, converted to RGB, and passed through the
    transform built by ``preprocessor``.
    """

    def __init__(
        self,
        coords,
        wsi_path,
        pretrained=False,
        patch_size=256,
        patch_level=0,
        macenko=True,
    ):
        """Open the slide and build the per-patch transform.

        Args:
            coords: Indexable, sized collection of patch locations; each
                entry is handed unchanged to ``OpenSlide.read_region`` as
                its ``location`` argument.
                NOTE(review): the OpenSlide API documents ``location`` as
                an (x, y) pair in the level-0 reference frame -- confirm
                that callers produce coordinates in that convention.
            wsi_path: Path of the slide file passed to ``OpenSlide``.
            pretrained: Stored on the instance and forwarded to
                ``preprocessor``.
            patch_size: Edge length used for both dimensions of the region
                read for each patch.
            patch_level: Slide level passed to ``read_region``.
            macenko: When truthy, a ``MacenkoNormalizer`` loaded from
                ``<parent of this file's directory>/models/macenko_param.pt``
                is handed to ``preprocessor``; otherwise ``None`` is passed.
        """
        self.pretrained = pretrained
        # The slide handle is opened once here and reused by every
        # __getitem__ call.
        self.wsi = OpenSlide(wsi_path)
        self.patch_size = patch_size
        self.patch_level = patch_level

        normalizer = None
        if macenko:
            # The parameter file lives in the "models" directory that sits
            # next to the directory containing this module.
            module_dir = os.path.dirname(__file__)
            project_root = os.path.dirname(module_dir)
            param_file = os.path.join(project_root, "models", "macenko_param.pt")
            normalizer = MacenkoNormalizer(target_path=param_file)

        self.roi_transforms = preprocessor(
            pretrained=pretrained,
            normalizer=normalizer,
        )

        self.coords = coords
        # Length is captured once; later changes to ``coords`` are not
        # reflected by __len__.
        self.length = len(self.coords)

    def __len__(self):
        """Return the number of coordinates captured at construction time."""
        return self.length

    def __getitem__(self, idx):
        """Return the transformed RGB patch for the coordinate at ``idx``."""
        location = self.coords[idx]
        extent = (self.patch_size, self.patch_size)
        region = self.wsi.read_region(location, self.patch_level, extent)
        rgb_patch = region.convert("RGB")
        return self.roi_transforms(rgb_patch)