import cv2 |
import numpy as np |
import torch |
from openslide import OpenSlide |
def extract_tissue_patch_coords(
    wsi_path: str,
    patch_size: int = 256,
    step_size: int = 256,
    downsample_threshold: float = 64,
    threshold: int = 8,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
    min_effective_area_factor: float = 100,
    ref_area: int = 512,
    min_hole_area_factor: float = 16,
    max_n_holes: int = 8,
) -> list:
    """
    Collect top-left coordinates of full-resolution patches whose centers fall within tissue.

    Process:
    1. Open the WSI, read its level-0 size and compute a binary mask at a segmentation level.
    2. Close the WSI (it is not needed once the mask exists).
    3. Find contours and holes from the mask and filter them using effective area criteria.
    4. Scale the external contours and holes to full resolution.
    5. Slide a window over the full-resolution grid and keep the window if its center is in tissue.

    Args:
        wsi_path: Path handed to ``OpenSlide``.
        patch_size: Side length of the sliding window, in level-0 pixels.
        step_size: Stride of the sliding window, in level-0 pixels.
        downsample_threshold: Forwarded to ``select_segmentation_level``.
        threshold, max_val, median_kernel, close_size: Forwarded to
            ``compute_segmentation_mask``.
        min_effective_area_factor: Multiplier for the minimum effective contour area.
        ref_area: Reference side length; ``ref_area**2`` is divided by the product of the
            segmentation-level downsample factors to express areas at that level.
        min_hole_area_factor: Multiplier for the minimum hole area (same reference area).
        max_n_holes: Forwarded to ``filter_contours_and_holes``.

    Returns:
        A list of ``(x, y)`` tuples: the top-left corner of every accepted window. Only windows
        that fit completely inside the level-0 image are considered.

    Raises:
        ValueError: If no contour survives filtering, or if no window center lies in tissue.
    """
    # The slide handle is only required for the level-0 size and the low-resolution mask;
    # the context manager guarantees it is closed even if mask computation fails.
    with OpenSlide(wsi_path) as slide:
        full_width, full_height = slide.level_dimensions[0]
        seg_level, scale = select_segmentation_level(slide, downsample_threshold)
        binary_mask = compute_segmentation_mask(
            slide, seg_level, threshold, max_val, median_kernel, close_size
        )

    # Area of the reference square expressed in segmentation-level pixels.
    ref_area_at_level = ref_area**2 / (scale[0] * scale[1])
    effective_area_thresh = min_effective_area_factor * ref_area_at_level
    hole_area_thresh = min_hole_area_factor * ref_area_at_level

    ext_contours, holes_list = filter_contours_and_holes(
        binary_mask, effective_area_thresh, hole_area_thresh, max_n_holes
    )
    if not ext_contours:
        raise ValueError("No valid tissue contours found.")

    tissue_contours = scale_contours(ext_contours, scale)
    scaled_holes = [scale_contours(holes, scale) for holes in holes_list]

    half_patch = patch_size // 2
    coords = []
    for y in range(0, full_height - patch_size + 1, step_size):
        center_y = y + half_patch
        for x in range(0, full_width - patch_size + 1, step_size):
            center_x = x + half_patch
            if point_in_tissue(center_x, center_y, tissue_contours, scaled_holes):
                coords.append((x, y))

    if not coords:
        raise ValueError("No available patches")
    return coords
def select_segmentation_level(slide: OpenSlide, downsample_threshold: float = 64):
    """
    Pick the pyramid level used for tissue segmentation.

    The level is whatever ``slide.get_best_level_for_downsample`` returns for
    ``downsample_threshold``.
    NOTE(review): OpenSlide describes that call as returning the best level for displaying the
    given downsample, so the chosen level's factor is not guaranteed to reach the threshold —
    confirm against the OpenSlide documentation.

    Returns:
        level (int): Chosen level index.
        scale (tuple): Downsample factors (sx, sy) for that level. A scalar factor is
            duplicated for both axes; a tuple/list factor is returned unchanged.
    """
    level = slide.get_best_level_for_downsample(downsample_threshold)
    factor = slide.level_downsamples[level]
    scale = factor if isinstance(factor, (tuple, list)) else (factor, factor)
    return level, scale
def compute_segmentation_mask(
    slide: OpenSlide,
    level: int,
    threshold: int = 20,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
):
    """
    Build a binary tissue mask from the saturation channel of one pyramid level.

    Process:
    - Read the whole image at the given level and convert it to RGB.
    - Convert to HSV and keep only the saturation channel.
    - Smooth it with a median blur of size ``median_kernel``.
    - Apply a fixed binary threshold (``threshold`` / ``max_val``).
    - If ``close_size`` is positive, apply morphological closing with a square kernel of ones.

    Returns:
        binary (ndarray): Binary mask image.
    """
    level_size = slide.level_dimensions[level]
    region = slide.read_region((0, 0), level, level_size)
    rgb = np.array(region.convert("RGB"))

    saturation = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)[:, :, 1]
    smoothed = cv2.medianBlur(saturation, median_kernel)
    # cv2.threshold returns (used_threshold, image); only the image is needed.
    mask = cv2.threshold(smoothed, threshold, max_val, cv2.THRESH_BINARY)[1]

    if close_size > 0:
        structuring = np.ones((close_size, close_size), dtype=np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, structuring)
    return mask
def filter_contours_and_holes(
    binary_mask: np.ndarray,
    min_effective_area: float,
    min_hole_area: float,
    max_n_holes: int,
):
    """
    Find contours from the binary mask and filter them based on effective area.

    For each external contour (one with no parent), its child contours (holes) are sorted by
    area (largest first); of the first ``max_n_holes`` only those whose area exceeds
    ``min_hole_area`` are kept. The effective area is the area of the external contour minus
    the summed area of the kept holes. Only contours whose effective area exceeds
    ``min_effective_area`` are retained.

    Args:
        binary_mask: Mask image passed to ``cv2.findContours``.
        min_effective_area: Exclusive lower bound on a contour's effective area.
        min_hole_area: Exclusive lower bound on the area of a kept hole.
        max_n_holes: Number of largest holes considered per external contour.

    Returns:
        filtered_contours (list): List of external contours (numpy arrays).
        holes_list (list): Corresponding list of lists of hole contours.
    """
    contours, hierarchy = cv2.findContours(
        binary_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
    )
    if hierarchy is None:
        return [], []
    hierarchy = hierarchy[0]

    # Group child indices by parent in a single pass instead of rescanning the whole
    # hierarchy for every external contour. Index order within each group is preserved.
    children_of = {}
    for idx, node in enumerate(hierarchy):
        parent = int(node[3])
        if parent != -1:
            children_of.setdefault(parent, []).append(idx)

    filtered_contours = []
    holes_list = []
    for idx, node in enumerate(hierarchy):
        if node[3] != -1:
            continue
        ext_cont = contours[idx]

        # Measure every hole exactly once; the area is reused for sorting, filtering
        # and the effective-area sum.
        measured = [
            (cv2.contourArea(contours[i]), contours[i])
            for i in children_of.get(idx, ())
        ]
        measured.sort(key=lambda pair: pair[0], reverse=True)
        kept = [
            (area, hole)
            for area, hole in measured[:max_n_holes]
            if area > min_hole_area
        ]

        effective_area = cv2.contourArea(ext_cont) - sum(area for area, _ in kept)
        if effective_area > min_effective_area:
            filtered_contours.append(ext_cont)
            holes_list.append([hole for _, hole in kept])
    return filtered_contours, holes_list
def scale_contours(contours: list, scale: tuple) -> list:
    """
    Multiply every contour's coordinates by per-axis scale factors.

    Args:
        contours: List of contours (each a numpy array of points).
        scale: Tuple (sx, sy) for scaling.

    Returns:
        A new list with one int32 array per input contour; fractional results are
        truncated by the integer cast.
    """
    # The factor array does not depend on the contour, so build it once.
    factors = np.array(scale, dtype=np.float32)
    return [(outline * factors).astype(np.int32) for outline in contours]
def point_in_tissue(x: int, y: int, ext_contours: list, holes_list: list) -> bool:
    """
    Tell whether point (x, y) is covered by tissue.

    ``ext_contours`` and ``holes_list`` are walked in lockstep. The point counts as tissue as
    soon as it is inside or on one external contour while being neither inside nor on any of
    that contour's holes. Otherwise the result is False.
    """
    point = (x, y)
    for outline, cavities in zip(ext_contours, holes_list):
        # pointPolygonTest with measureDist=False is negative only for points outside.
        if cv2.pointPolygonTest(outline, point, False) < 0:
            continue
        in_cavity = any(
            cv2.pointPolygonTest(cavity, point, False) >= 0 for cavity in cavities
        )
        if not in_cavity:
            return True
    return False
def tile(x: torch.Tensor, size: int):
    """
    Cut an image tensor into non-overlapping square tiles.

    The last three dimensions of ``x`` are read as (C, H, W); any leading dimensions are
    flattened into the tile axis. If H or W is not a multiple of ``size``, the tensor is
    padded at the bottom and right (``torch.nn.functional.pad`` default mode) up to the next
    multiple before tiling.

    Args:
        x: Tensor with at least three dimensions, e.g. (C, H, W) or (B, C, H, W).
        size: Side length of each square tile.

    Returns:
        Tensor of shape (N, C, size, size). Tiles are ordered by leading index first, then
        tile row, then tile column.
    """
    C, H, W = x.shape[-3:]
    pad_h = (size - H % size) % size
    pad_w = (size - W % size) % size
    if pad_h > 0 or pad_w > 0:
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))
    # Index from the end so an unbatched (C, H, W) input works like a batched one.
    nh, nw = x.shape[-2] // size, x.shape[-1] // size
    # reshape (rather than view) also accepts inputs whose leading dimensions cannot be
    # merged without a copy.
    return (
        x.reshape(-1, C, nh, size, nw, size)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(-1, C, size, size)
    )