import cv2
import numpy as np
import torch
from openslide import OpenSlide


def extract_tissue_patch_coords(
    wsi_path: str,
    patch_size: int = 256,
    step_size: int = 256,
    downsample_threshold: float = 64,
    threshold: int = 8,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
    min_effective_area_factor: float = 100,  # multiplied by (ref_area)^2
    ref_area: int = 512,
    min_hole_area_factor: float = 16,  # multiplied by (ref_area)^2
    max_n_holes: int = 8,
) -> list:
    """
    Compute top-left coordinates of full-resolution patches whose centers
    fall within tissue regions of a whole-slide image.

    Process:
      1. Open the WSI.
      2. Select a segmentation level and compute a binary mask.
      3. Find contours and holes from the mask and filter them using
         effective area criteria.
      4. Scale the external contours and holes to full resolution.
      5. Slide a window over the full-resolution grid and keep the window's
         top-left corner if its center is in tissue.

    Args:
        wsi_path: Path handed to ``OpenSlide``.
        patch_size: Side length of the square window, in level-0 pixels.
        step_size: Stride between consecutive windows, in level-0 pixels.
        downsample_threshold: Downsample passed to
            ``select_segmentation_level``.
        threshold, max_val, median_kernel, close_size: Forwarded to
            ``compute_segmentation_mask``.
        min_effective_area_factor: Multiplied by ``ref_area**2`` (rescaled to
            the segmentation level) to get the minimum effective contour area.
        ref_area: Reference side length used by both area thresholds.
        min_hole_area_factor: Multiplied by ``ref_area**2`` (rescaled to the
            segmentation level) to get the minimum hole area.
        max_n_holes: Maximum number of holes considered per contour.

    Returns:
        A list of ``(x, y)`` tuples: top-left corners of accepted patches in
        level-0 coordinates, in row-major scan order.

    Raises:
        ValueError: If ``patch_size`` or ``step_size`` is not positive, if no
            tissue contour survives filtering, or if no patch center lies in
            tissue.
    """
    if patch_size <= 0 or step_size <= 0:
        raise ValueError("patch_size and step_size must be positive")

    slide = OpenSlide(wsi_path)
    try:
        full_width, full_height = slide.level_dimensions[0]
        seg_level, scale = select_segmentation_level(slide, downsample_threshold)
        binary_mask = compute_segmentation_mask(
            slide, seg_level, threshold, max_val, median_kernel, close_size
        )
    finally:
        # The slide is only needed to build the mask; release the handle even
        # if reading or segmentation fails.
        slide.close()

    # Area thresholds are specified at full resolution; divide by the
    # per-axis downsample product to express them in mask pixels.
    ref_area_at_level = ref_area**2 / (scale[0] * scale[1])
    effective_area_thresh = min_effective_area_factor * ref_area_at_level
    hole_area_thresh = min_hole_area_factor * ref_area_at_level

    ext_contours, holes_list = filter_contours_and_holes(
        binary_mask, effective_area_thresh, hole_area_thresh, max_n_holes
    )
    if not ext_contours:
        raise ValueError("No valid tissue contours found.")

    tissue_contours = scale_contours(ext_contours, scale)
    scaled_holes = [scale_contours(holes, scale) for holes in holes_list]

    half = patch_size // 2
    coords = []
    for y in range(0, full_height - patch_size + 1, step_size):
        center_y = y + half
        for x in range(0, full_width - patch_size + 1, step_size):
            if point_in_tissue(x + half, center_y, tissue_contours, scaled_holes):
                coords.append((x, y))

    if not coords:
        raise ValueError("No available patches")
    return coords


def select_segmentation_level(slide: OpenSlide, downsample_threshold: float = 64):
    """
    Select the pyramid level used for tissue segmentation.

    The level is whatever ``slide.get_best_level_for_downsample`` returns for
    ``downsample_threshold``; no further check against the threshold is made
    here.

    Returns:
        level (int): Chosen level index.
        scale (tuple): Downsample factors (sx, sy) for that level. A scalar
            downsample is duplicated for both axes.
    """
    level = slide.get_best_level_for_downsample(downsample_threshold)
    ds = slide.level_downsamples[level]
    if not isinstance(ds, (tuple, list)):
        ds = (ds, ds)
    return level, ds


def compute_segmentation_mask(
    slide: OpenSlide,
    level: int,
    threshold: int = 20,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
):
    """
    Compute a binary mask for tissue segmentation at the specified level.

    Process:
      - Read the whole image at the given level and convert to RGB.
      - Convert the image to HSV and extract the saturation channel.
      - Apply median blur with kernel size ``median_kernel``.
      - Apply fixed binary thresholding: pixels above ``threshold`` become
        ``max_val``, the rest 0.
      - Apply morphological closing with a square ``close_size`` kernel
        (skipped when ``close_size`` is not positive).

    Returns:
        binary (ndarray): Binary mask image.
    """
    # NOTE(review): cv2.medianBlur is documented to require an odd kernel
    # size greater than 1 - callers should pass such a value.
    region = slide.read_region((0, 0), level, slide.level_dimensions[level])
    img = np.array(region.convert("RGB"))
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    sat = hsv[:, :, 1]
    blurred = cv2.medianBlur(sat, median_kernel)
    _, binary = cv2.threshold(blurred, threshold, max_val, cv2.THRESH_BINARY)
    if close_size > 0:
        kernel = np.ones((close_size, close_size), np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    return binary


def filter_contours_and_holes(
    binary_mask: np.ndarray,
    min_effective_area: float,
    min_hole_area: float,
    max_n_holes: int,
):
    """
    Find contours from the binary mask and filter them based on effective area.

    For each external contour (one with no parent), identify child contours
    (holes), sort them by area (largest first), and keep those among the
    ``max_n_holes`` largest that exceed ``min_hole_area``. The effective area
    is the area of the external contour minus the sum of areas of the
    selected holes. Only contours with effective area above
    ``min_effective_area`` are retained.

    Returns:
        filtered_contours (list): List of external contours (numpy arrays).
        holes_list (list): Corresponding list of lists of hole contours.
    """
    # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
    # (contours, hierarchy); taking the last two entries works for both.
    contours, hierarchy = cv2.findContours(
        binary_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
    )[-2:]
    if hierarchy is None:
        return [], []
    hierarchy = hierarchy[0]  # shape: (N, 4); column 3 is the parent index

    # Index children by parent once instead of rescanning the hierarchy for
    # every external contour.
    children = {}
    for idx, entry in enumerate(hierarchy):
        parent = int(entry[3])
        if parent != -1:
            children.setdefault(parent, []).append(idx)

    filtered_contours = []
    holes_list = []
    for idx, entry in enumerate(hierarchy):
        if entry[3] != -1:
            continue  # Only external contours

        ext_cont = contours[idx]

        # Pair each hole with its area so the area is computed only once.
        holes_with_area = [
            (cv2.contourArea(contours[i]), contours[i])
            for i in children.get(idx, [])
        ]
        holes_with_area.sort(key=lambda pair: pair[0], reverse=True)
        selected = [
            (area, hole)
            for area, hole in holes_with_area[:max_n_holes]
            if area > min_hole_area
        ]

        total_hole_area = sum(area for area, _ in selected)
        effective_area = cv2.contourArea(ext_cont) - total_hole_area
        if effective_area > min_effective_area:
            filtered_contours.append(ext_cont)
            holes_list.append([hole for _, hole in selected])

    return filtered_contours, holes_list


def scale_contours(contours: list, scale: tuple) -> list:
    """
    Scale contour coordinates by the provided scale factors.

    Args:
        contours: List of contours (each a numpy array of points).
        scale: Tuple (sx, sy) for scaling.

    Returns:
        List of scaled contours, converted to int32 (fractional parts are
        truncated by the cast).
    """
    factors = np.array(scale, dtype=np.float32)
    return [(cont * factors).astype(np.int32) for cont in contours]


def point_in_tissue(x: int, y: int, ext_contours: list, holes_list: list) -> bool:
    """
    Check if point (x, y) lies within any external contour and not inside its
    corresponding holes.

    ``ext_contours`` and ``holes_list`` are walked in parallel. Points on a
    contour boundary count as inside that contour; points on a hole boundary
    count as inside the hole. Returns True for the first contour that
    contains the point outside all of its holes, otherwise False.
    """
    point = (x, y)
    for cont, holes in zip(ext_contours, holes_list):
        if cv2.pointPolygonTest(cont, point, False) < 0:
            continue
        in_hole = any(
            cv2.pointPolygonTest(hole, point, False) >= 0 for hole in holes
        )
        if not in_hole:
            return True
    return False


def tile(x: torch.Tensor, size: int):
    """
    Split an image tensor into non-overlapping ``size`` x ``size`` tiles.

    The last three dimensions of ``x`` are treated as (C, H, W); any leading
    dimensions are flattened into the output batch. H and W are padded on the
    bottom/right (``torch.nn.functional.pad`` defaults) up to the next
    multiple of ``size`` before tiling.

    Returns:
        Tensor of shape (-1, C, size, size); tiles of one image are ordered
        row-major (tile row first, then tile column).
    """
    C, H, W = x.shape[-3:]
    pad_h = (size - H % size) % size
    pad_w = (size - W % size) % size
    if pad_h > 0 or pad_w > 0:
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))
    # Index from the end so 3-D (C, H, W) inputs work as well as batched ones.
    nh, nw = x.shape[-2] // size, x.shape[-1] // size
    return (
        # reshape (not view) so non-contiguous inputs are accepted.
        x.reshape(-1, C, nh, size, nw, size)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(-1, C, size, size)
    )