|
import os |
|
|
|
from openslide import OpenSlide |
|
from scripts.preprocessor import MacenkoNormalizer, preprocessor |
|
from torch.utils.data import Dataset |
|
|
|
|
|
class WSIPatchDataset(Dataset):
    """Dataset that reads square patches from a single whole-slide image.

    Each item is produced by reading one region from the slide at the
    coordinate stored in ``coords``, converting it to RGB, and passing it
    through the transform pipeline returned by ``preprocessor``.

    The slide is opened in ``__init__`` and stays open for the lifetime of
    the dataset; call :meth:`close` when the dataset is no longer needed.
    """

    def __init__(
        self,
        coords,
        wsi_path,
        pretrained=False,
        patch_size=256,
        patch_level=0,
        macenko=True,
    ):
        """Open the slide and build the patch transform pipeline.

        Args:
            coords: Indexable, sized collection of patch locations. Each
                element is forwarded unchanged as the ``location`` argument
                of ``OpenSlide.read_region`` (presumably an ``(x, y)`` pair
                in level-0 pixels -- confirm against the caller).
            wsi_path: Path of the slide file handed to ``OpenSlide``.
            pretrained: Forwarded to ``preprocessor`` and stored on the
                instance.
            patch_size: Width and height, in pixels, of the region read for
                every patch.
            patch_level: Slide level passed to ``read_region``.
            macenko: When true, a ``MacenkoNormalizer`` loaded from
                ``models/macenko_param.pt`` (located two directory levels
                above this file) is given to ``preprocessor``; otherwise
                the normalizer is ``None``.
        """
        self.pretrained = pretrained
        self.wsi = OpenSlide(wsi_path)
        self.patch_size = patch_size
        self.patch_level = patch_level

        if macenko:
            # Two directory levels above this file. ``__file__`` is already
            # a complete path, so it needs no ``os.path.join`` wrapper.
            base_dir = os.path.dirname(os.path.dirname(__file__))
            normalizer = MacenkoNormalizer(
                target_path=os.path.join(base_dir, "models", "macenko_param.pt")
            )
        else:
            normalizer = None

        self.roi_transforms = preprocessor(pretrained=pretrained, normalizer=normalizer)
        self.coords = coords
        # Cached once; ``__len__`` returns this value.
        self.length = len(self.coords)

    def __len__(self):
        """Return the number of patch coordinates captured at construction."""
        return self.length

    def __getitem__(self, idx):
        """Read the patch at ``coords[idx]`` and return it transformed.

        The region is read at ``patch_level`` with size
        ``(patch_size, patch_size)``, converted to RGB, and passed through
        ``roi_transforms``; whatever that callable returns is returned here.
        """
        coord = self.coords[idx]
        region = self.wsi.read_region(
            coord, self.patch_level, (self.patch_size, self.patch_size)
        )
        # Convert to "RGB" before handing the patch to the transforms.
        img = region.convert("RGB")
        return self.roi_transforms(img)

    def close(self):
        """Close the underlying slide handle.

        The dataset must not be indexed after this call.
        """
        self.wsi.close()
|
|