|
import cv2 |
|
import numpy as np |
|
import torch |
|
from openslide import OpenSlide |
|
|
|
|
|
def extract_tissue_patch_coords(
    wsi_path: str,
    patch_size: int = 256,
    step_size: int = 256,
    downsample_threshold: float = 64,
    threshold: int = 8,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
    min_effective_area_factor: float = 100,
    ref_area: int = 512,
    min_hole_area_factor: float = 16,
    max_n_holes: int = 8,
) -> list:
    """
    Collect top-left coordinates of full-resolution patches whose centers lie in tissue.

    Process:
      1. Open the WSI and read its level-0 dimensions.
      2. Select a segmentation level and compute a binary mask at that level.
      3. Find contours and holes in the mask and filter them by effective area.
      4. Scale the retained contours and holes to level-0 coordinates.
      5. Slide a window over the level-0 grid and keep every window whose
         center point is inside a contour and outside that contour's holes.

    Args:
        wsi_path: Path passed to ``OpenSlide``.
        patch_size: Side length of the square window, in level-0 pixels.
        step_size: Stride of the sliding window, in level-0 pixels.
        downsample_threshold: Forwarded to ``select_segmentation_level``.
        threshold, max_val, median_kernel, close_size: Forwarded to
            ``compute_segmentation_mask``.
        min_effective_area_factor: Multiplier for the minimum effective area;
            the threshold is ``factor * ref_area**2 / (sx * sy)``.
        ref_area: Reference side length used in both area thresholds.
        min_hole_area_factor: Multiplier for the minimum hole area, scaled the
            same way as the effective-area threshold.
        max_n_holes: Maximum number of holes considered per contour.

    Returns:
        List of ``(x, y)`` tuples: the top-left corner of each accepted patch.
        Only windows that fit entirely inside the level-0 image are visited.

    Raises:
        ValueError: If no contour survives filtering, or no patch center
            falls inside tissue.
    """
    slide = OpenSlide(wsi_path)
    try:
        full_width, full_height = slide.level_dimensions[0]
        seg_level, scale = select_segmentation_level(slide, downsample_threshold)
        binary_mask = compute_segmentation_mask(
            slide, seg_level, threshold, max_val, median_kernel, close_size
        )
    finally:
        # The slide is only needed to build the mask; release the handle even
        # when reading or segmentation fails.
        slide.close()

    # Both thresholds are expressed at the segmentation level: the reference
    # area is divided by the per-axis downsample product.
    ref_area_at_level = ref_area**2 / (scale[0] * scale[1])
    effective_area_thresh = min_effective_area_factor * ref_area_at_level
    hole_area_thresh = min_hole_area_factor * ref_area_at_level

    ext_contours, holes_list = filter_contours_and_holes(
        binary_mask, effective_area_thresh, hole_area_thresh, max_n_holes
    )
    if not ext_contours:
        raise ValueError("No valid tissue contours found.")

    tissue_contours = scale_contours(ext_contours, scale)
    scaled_holes = [scale_contours(holes, scale) for holes in holes_list]

    half_patch = patch_size // 2
    coords = []
    for y in range(0, full_height - patch_size + 1, step_size):
        center_y = y + half_patch
        for x in range(0, full_width - patch_size + 1, step_size):
            if point_in_tissue(x + half_patch, center_y, tissue_contours, scaled_holes):
                coords.append((x, y))

    if not coords:
        raise ValueError("No available patches")
    return coords
|
|
|
|
|
def select_segmentation_level(slide: OpenSlide, downsample_threshold: float = 64):
    """
    Pick the pyramid level used for tissue segmentation.

    The level is whatever ``slide.get_best_level_for_downsample`` returns for
    ``downsample_threshold``.
    NOTE(review): the level's actual downsample may differ from the requested
    value (possibly lower) -- confirm against the OpenSlide documentation.

    Returns:
        level (int): Chosen level index.
        scale (tuple): Downsample factors (sx, sy) for that level. A scalar
            downsample is duplicated for both axes; a tuple/list is returned as is.
    """
    chosen_level = slide.get_best_level_for_downsample(downsample_threshold)
    factor = slide.level_downsamples[chosen_level]
    scale = factor if isinstance(factor, (tuple, list)) else (factor, factor)
    return chosen_level, scale
|
|
|
|
|
def compute_segmentation_mask(
    slide: OpenSlide,
    level: int,
    threshold: int = 20,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
):
    """
    Build a binary tissue mask from the saturation channel at one pyramid level.

    Process:
      - Read the whole image at ``level`` and convert it to RGB.
      - Convert to HSV and take the saturation channel.
      - Median-blur the saturation with a ``median_kernel`` aperture.
      - Apply a fixed binary threshold (``threshold`` -> ``max_val``).
      - If ``close_size`` is positive, apply morphological closing with a
        square ``close_size`` x ``close_size`` kernel of ones.

    Returns:
        binary (ndarray): Binary mask image at the resolution of ``level``.
    """
    region = slide.read_region((0, 0), level, slide.level_dimensions[level])
    rgb = np.array(region.convert("RGB"))

    saturation = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)[:, :, 1]
    smoothed = cv2.medianBlur(saturation, median_kernel)
    mask = cv2.threshold(smoothed, threshold, max_val, cv2.THRESH_BINARY)[1]

    # Closing is optional; a non-positive size returns the raw thresholded mask.
    if close_size <= 0:
        return mask
    structuring_element = np.ones((close_size, close_size), np.uint8)
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, structuring_element)
|
|
|
|
|
def filter_contours_and_holes(
    binary_mask: np.ndarray,
    min_effective_area: float,
    min_hole_area: float,
    max_n_holes: int,
):
    """
    Find contours in the binary mask and filter them by effective area.

    Contours are extracted with ``cv2.RETR_CCOMP``. For each external contour
    (one with no parent), its child contours (holes) are sorted by area,
    largest first; among the first ``max_n_holes`` of them, those with area
    strictly greater than ``min_hole_area`` are kept. The effective area is
    the external contour's area minus the summed area of the kept holes. Only
    contours whose effective area is strictly greater than
    ``min_effective_area`` are retained.

    Args:
        binary_mask: Mask image passed to ``cv2.findContours``.
        min_effective_area: Exclusive lower bound on the effective area.
        min_hole_area: Exclusive lower bound on the area of a kept hole.
        max_n_holes: Number of largest holes examined per external contour.

    Returns:
        filtered_contours (list): List of external contours (numpy arrays).
        holes_list (list): Corresponding list of lists of hole contours.
        Both lists are empty when no contours are found.
    """
    contours, hierarchy = cv2.findContours(
        binary_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
    )
    if hierarchy is None:
        return [], []
    hierarchy = hierarchy[0]

    # Index children by parent in one pass instead of rescanning the whole
    # hierarchy for every external contour. Children stay in ascending index
    # order, so ties between equal-area holes resolve as before.
    children_of = {}
    for idx, node in enumerate(hierarchy):
        parent = int(node[3])
        if parent != -1:
            children_of.setdefault(parent, []).append(idx)

    filtered_contours = []
    holes_list = []
    for idx, node in enumerate(hierarchy):
        if node[3] != -1:
            continue  # not an external contour
        ext_cont = contours[idx]

        # Compute each hole's area once and carry it alongside the contour.
        holes_with_area = [
            (cv2.contourArea(contours[i]), contours[i])
            for i in children_of.get(idx, ())
        ]
        holes_with_area.sort(key=lambda pair: pair[0], reverse=True)
        kept = [
            (area, hole)
            for area, hole in holes_with_area[:max_n_holes]
            if area > min_hole_area
        ]

        effective_area = cv2.contourArea(ext_cont) - sum(area for area, _ in kept)
        if effective_area > min_effective_area:
            filtered_contours.append(ext_cont)
            holes_list.append([hole for _, hole in kept])
    return filtered_contours, holes_list
|
|
|
|
|
def scale_contours(contours: list, scale: tuple) -> list: |
|
""" |
|
Scale contour coordinates by the provided scale factors. |
|
|
|
Args: |
|
contours: List of contours (each a numpy array of points). |
|
scale: Tuple (sx, sy) for scaling. |
|
|
|
Returns: |
|
List of scaled contours. |
|
""" |
|
scaled = [] |
|
for cont in contours: |
|
scaled.append((cont * np.array(scale, dtype=np.float32)).astype(np.int32)) |
|
return scaled |
|
|
|
|
|
def point_in_tissue(x: int, y: int, ext_contours: list, holes_list: list) -> bool:
    """
    Report whether point (x, y) is in tissue: inside an external contour but in none of its holes.

    ``ext_contours`` and ``holes_list`` are walked in lockstep. A point on a
    contour's boundary counts as inside it, and a point on a hole's boundary
    counts as inside the hole. The first contour that contains the point
    without any of its holes containing it yields True; otherwise False.
    """
    point = (x, y)
    for outline, cavities in zip(ext_contours, holes_list):
        if cv2.pointPolygonTest(outline, point, False) < 0:
            continue  # strictly outside this contour
        in_any_hole = any(
            cv2.pointPolygonTest(cavity, point, False) >= 0 for cavity in cavities
        )
        if not in_any_hole:
            return True
    return False
|
|
|
|
|
def tile(x: torch.Tensor, size: int) -> torch.Tensor:
    """
    Split an image tensor into non-overlapping ``size`` x ``size`` tiles.

    The last three dimensions of ``x`` are read as (C, H, W); any leading
    dimensions are flattened into the output's first axis, so both
    ``(C, H, W)`` and ``(B, C, H, W)`` inputs are accepted. When H or W is not
    a multiple of ``size``, the bottom and right edges are padded
    (``torch.nn.functional.pad`` defaults, i.e. constant zeros) up to the
    next multiple.

    Args:
        x: Tensor with at least three dimensions, ending in (C, H, W).
        size: Tile side length.

    Returns:
        Tensor of shape ``(N, C, size, size)``. Tiles are ordered by leading
        index first, then tile row, then tile column.
    """
    channels, height, width = x.shape[-3:]

    pad_h = (size - height % size) % size
    pad_w = (size - width % size) % size
    if pad_h > 0 or pad_w > 0:
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))

    # Derive the grid from the trailing dims rather than fixed axis indices so
    # inputs without a batch dimension work as well.
    n_rows = x.shape[-2] // size
    n_cols = x.shape[-1] // size
    # reshape (not view) so non-contiguous inputs are handled when no padding
    # copy was made above.
    return (
        x.reshape(-1, channels, n_rows, size, n_cols, size)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(-1, channels, size, size)
    )
|
|