# Author: 편명장 (myeongjang.pyeon)
# Commit 287a683: initial commit
import os
from openslide import OpenSlide
from scripts.preprocessor import MacenkoNormalizer, preprocessor
from torch.utils.data import Dataset
class WSIPatchDataset(Dataset):
    """Map-style dataset yielding fixed-size patches from a single whole-slide image.

    Item ``idx`` is produced in three steps:

    1. Read the square region at ``coords[idx]`` with ``OpenSlide.read_region``.
    2. Convert it to RGB.
    3. Pass it through the transform returned by ``preprocessor``.

    Args:
        coords: Indexable, sized sequence of ``(x, y)`` patch origins. They are
            passed as the ``location`` argument of ``OpenSlide.read_region``.
            OpenSlide interprets that argument in the level-0 reference frame,
            regardless of ``patch_level``.
        wsi_path: Path to a slide file that OpenSlide can open.
        pretrained: Forwarded to ``preprocessor``.
        patch_size: Side length in pixels (at ``patch_level``) of each patch.
        patch_level: Pyramid level to read from.
        macenko: If True, build a ``MacenkoNormalizer`` and pass it to
            ``preprocessor``. Otherwise the normalizer is ``None``.
        macenko_target_path: Optional override for the normalizer's
            ``target_path``. Defaults to
            ``<parent of this file's directory>/models/macenko_param.pt``.

    The slide handle stays open for the dataset's lifetime. Call ``close()``
    or use the dataset as a context manager to release it.
    """

    def __init__(
        self,
        coords,
        wsi_path,
        pretrained=False,
        patch_size=256,
        patch_level=0,
        macenko=True,
        macenko_target_path=None,
    ):
        self.pretrained = pretrained
        self.patch_size = patch_size
        self.patch_level = patch_level

        if macenko:
            target_path = (
                macenko_target_path
                if macenko_target_path is not None
                else self._default_macenko_path()
            )
            normalizer = MacenkoNormalizer(target_path=target_path)
        else:
            normalizer = None
        self.roi_transforms = preprocessor(pretrained=pretrained, normalizer=normalizer)

        self.coords = coords
        self.length = len(self.coords)

        # Open the slide last, so a failure in the setup above cannot leak
        # the file handle.
        self.wsi = OpenSlide(wsi_path)

    @staticmethod
    def _default_macenko_path():
        """Return the absolute path of the bundled Macenko target parameters."""
        # abspath keeps the result independent of the CWD when __file__ is relative.
        package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(package_root, "models", "macenko_param.pt")

    @staticmethod
    def _as_location(coord):
        """Return ``coord`` as an ``(x, y)`` tuple of built-in ``int``.

        Coordinates often come from NumPy/torch containers. OpenSlide's ctypes
        layer may reject those scalar types, so they are converted here.
        Non-integral values are truncated by ``int()``.
        """
        x, y = coord
        return int(x), int(y)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        location = self._as_location(self.coords[idx])
        region = self.wsi.read_region(
            location, self.patch_level, (self.patch_size, self.patch_size)
        )
        # read_region returns RGBA. convert("RGB") simply drops alpha.
        # NOTE(review): transparent (no-data) pixels therefore become black.
        # Consider compositing onto white if that hurts stain normalization.
        img = region.convert("RGB")
        return self.roi_transforms(img)

    def close(self):
        """Release the underlying OpenSlide handle."""
        self.wsi.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()