편명장/님/(myeongjang.pyeon)
initial commit
287a683
import cv2
import numpy as np
import torch
from openslide import OpenSlide
def extract_tissue_patch_coords(
    wsi_path: str,
    patch_size: int = 256,
    step_size: int = 256,
    downsample_threshold: float = 64,
    threshold: int = 8,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
    min_effective_area_factor: float = 100,  # scaled by ref_area**2 / (sx * sy)
    ref_area: int = 512,
    min_hole_area_factor: float = 16,  # scaled by ref_area**2 / (sx * sy)
    max_n_holes: int = 8,
) -> list:
    """
    Collect the top-left coordinates of full-resolution patches whose centers
    fall within tissue regions.

    Process:
    1. Open the WSI.
    2. Select a segmentation level and compute a binary mask.
    3. Find contours and holes from the mask and filter them using effective area criteria.
    4. Scale the external contours and holes to full resolution.
    5. Slide a window over the full-resolution grid and keep a position if its
       patch center is in tissue.

    Args:
        wsi_path: Path of the slide passed to OpenSlide.
        patch_size: Side length of the square window, in level-0 pixels.
        step_size: Stride of the sliding window, in level-0 pixels.
        downsample_threshold: Downsample factor used to pick the segmentation level.
        threshold, max_val, median_kernel, close_size: Forwarded unchanged to
            compute_segmentation_mask.
        min_effective_area_factor: Factor for the minimum effective contour area.
        ref_area: Reference side length; both area thresholds are
            factor * ref_area**2 / (sx * sy), with (sx, sy) the downsample
            factors of the segmentation level.
        min_hole_area_factor: Factor for the minimum hole area.
        max_n_holes: Maximum number of holes kept per external contour.

    Returns:
        A list of (x, y) tuples, the top-left corner of every accepted patch in
        level-0 coordinates. No pixel data is read for the patches themselves.

    Raises:
        ValueError: If no contour survives filtering, or no patch center lies
            in tissue.
    """
    slide = OpenSlide(wsi_path)
    try:
        full_width, full_height = slide.level_dimensions[0]
        seg_level, scale = select_segmentation_level(slide, downsample_threshold)
        binary_mask = compute_segmentation_mask(
            slide, seg_level, threshold, max_val, median_kernel, close_size
        )
    finally:
        # The slide is only needed to build the mask; release the handle even
        # when segmentation fails.
        slide.close()

    # Compute thresholds for effective area and hole area, expressed in
    # segmentation-level pixels.
    effective_area_thresh = min_effective_area_factor * (
        ref_area**2 / (scale[0] * scale[1])
    )
    hole_area_thresh = min_hole_area_factor * (ref_area**2 / (scale[0] * scale[1]))

    ext_contours, holes_list = filter_contours_and_holes(
        binary_mask, effective_area_thresh, hole_area_thresh, max_n_holes
    )
    if not ext_contours:
        raise ValueError("No valid tissue contours found.")

    # Bring contours and holes back to level-0 coordinates.
    tissue_contours = scale_contours(ext_contours, scale)
    scaled_holes = [scale_contours(holes, scale) for holes in holes_list]

    half_patch = patch_size // 2
    coords = []
    for y in range(0, full_height - patch_size + 1, step_size):
        center_y = y + half_patch
        for x in range(0, full_width - patch_size + 1, step_size):
            if point_in_tissue(x + half_patch, center_y, tissue_contours, scaled_holes):
                coords.append((x, y))

    if not coords:
        raise ValueError("No available patches")
    return coords
def select_segmentation_level(slide: OpenSlide, downsample_threshold: float = 64):
    """
    Pick the pyramid level OpenSlide reports as best for the requested downsample.

    Args:
        slide: Open slide to query.
        downsample_threshold: Requested downsample factor.

    Returns:
        level (int): Chosen level index.
        scale (tuple): Downsample factors (sx, sy) for that level. A scalar
            factor is duplicated into both axes.

    NOTE(review): the level comes straight from get_best_level_for_downsample;
    its factor is not guaranteed here to be >= downsample_threshold - confirm
    against the OpenSlide documentation.
    """
    chosen_level = slide.get_best_level_for_downsample(downsample_threshold)
    factor = slide.level_downsamples[chosen_level]
    scale = factor if isinstance(factor, (tuple, list)) else (factor, factor)
    return chosen_level, scale
def compute_segmentation_mask(
    slide: OpenSlide,
    level: int,
    threshold: int = 20,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
):
    """
    Build a binary tissue mask from the saturation channel of one slide level.

    Process:
    - Read the whole image at the given level and convert it to RGB.
    - Convert to HSV and keep the saturation channel.
    - Smooth with a median blur of aperture median_kernel.
    - Apply a fixed binary threshold (pixels above threshold become max_val).
    - Optionally apply morphological closing with a close_size x close_size kernel.

    Args:
        slide: Open slide to read from.
        level: Pyramid level to read in full.
        threshold: Saturation cut-off passed to cv2.threshold.
        max_val: Value assigned to pixels above the cut-off.
        median_kernel: Aperture size passed to cv2.medianBlur.
        close_size: Side of the square closing kernel; closing is skipped when
            this is not positive.

    Returns:
        binary (ndarray): Binary mask image.
    """
    level_size = slide.level_dimensions[level]
    region = slide.read_region((0, 0), level, level_size)
    rgb = np.array(region.convert("RGB"))

    saturation = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)[:, :, 1]
    smoothed = cv2.medianBlur(saturation, median_kernel)
    mask = cv2.threshold(smoothed, threshold, max_val, cv2.THRESH_BINARY)[1]

    if close_size <= 0:
        return mask

    structuring_element = np.ones((close_size, close_size), np.uint8)
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, structuring_element)
def filter_contours_and_holes(
    binary_mask: np.ndarray,
    min_effective_area: float,
    min_hole_area: float,
    max_n_holes: int,
):
    """
    Find contours from the binary mask and filter them based on effective area.

    For each external contour (one with no parent), identify child contours (holes),
    sort them by area (largest first), and keep those among the first max_n_holes
    whose area exceeds min_hole_area.
    The effective area is computed as the area of the external contour minus the sum of areas
    of the selected holes. Only contours with effective area above min_effective_area are retained.

    Args:
        binary_mask: Binary image passed to cv2.findContours.
        min_effective_area: Minimum effective area (exclusive) for an external contour.
        min_hole_area: Minimum area (exclusive) for a hole to be kept.
        max_n_holes: Number of largest holes considered per external contour.

    Returns:
        filtered_contours (list): List of external contours (numpy arrays).
        holes_list (list): Corresponding list of lists of hole contours.
    """
    # [-2:] also accepts the 3-tuple (image, contours, hierarchy) returned by
    # OpenCV 3.x; on 4.x the 2-tuple is returned unchanged.
    contours, hierarchy = cv2.findContours(
        binary_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
    )[-2:]
    if hierarchy is None:
        return [], []
    hierarchy = hierarchy[0]  # shape: (N, 4); column 3 holds the parent index

    # Group child indices by parent in a single pass, so looking up the holes
    # of a contour does not rescan the whole hierarchy each time.
    children_by_parent = {}
    for child_idx, entry in enumerate(hierarchy):
        parent_idx = int(entry[3])
        if parent_idx != -1:
            children_by_parent.setdefault(parent_idx, []).append(child_idx)

    filtered_contours = []
    holes_list = []
    for idx, entry in enumerate(hierarchy):
        if entry[3] != -1:
            continue  # Only external contours
        ext_cont = contours[idx]

        # Measure every hole exactly once, then rank by area (largest first).
        # The sort is stable, so equal-area holes keep their index order.
        measured_holes = [
            (cv2.contourArea(contours[i]), contours[i])
            for i in children_by_parent.get(idx, [])
        ]
        measured_holes.sort(key=lambda pair: pair[0], reverse=True)

        # Truncate first, then drop holes that are too small.
        kept = [
            (area, hole)
            for area, hole in measured_holes[:max_n_holes]
            if area > min_hole_area
        ]
        selected_holes = [hole for _, hole in kept]
        total_hole_area = sum(area for area, _ in kept)

        effective_area = cv2.contourArea(ext_cont) - total_hole_area
        if effective_area > min_effective_area:
            filtered_contours.append(ext_cont)
            holes_list.append(selected_holes)
    return filtered_contours, holes_list
def scale_contours(contours: list, scale: tuple) -> list:
    """
    Multiply every contour's coordinates by per-axis scale factors.

    Args:
        contours: List of contours (each a numpy array of points).
        scale: Tuple (sx, sy) applied to the last axis of each contour.

    Returns:
        New list of int32 contours; fractional results are truncated by the
        integer cast.
    """
    factors = np.array(scale, dtype=np.float32)
    return [(points * factors).astype(np.int32) for points in contours]
def point_in_tissue(x: int, y: int, ext_contours: list, holes_list: list) -> bool:
    """
    Tell whether point (x, y) is tissue: inside or on an external contour and
    outside every hole paired with that contour.

    ext_contours and holes_list are walked in lockstep; the first contour that
    contains the point without any of its holes containing it yields True.
    Points lying exactly on a hole boundary count as inside the hole.
    """
    point = (x, y)
    for outline, cavities in zip(ext_contours, holes_list):
        if cv2.pointPolygonTest(outline, point, False) < 0:
            continue  # strictly outside this contour
        if not any(
            cv2.pointPolygonTest(cavity, point, False) >= 0 for cavity in cavities
        ):
            return True
    return False
def tile(x: torch.Tensor, size: int):
    """
    Split an image tensor into non-overlapping size x size tiles.

    The last three dimensions of x are read as (C, H, W); any leading
    dimensions are flattened into the output batch. H and W are zero-padded at
    the bottom/right up to the next multiple of size before tiling.

    Args:
        x: Tensor of shape (..., C, H, W), e.g. (C, H, W) or (B, C, H, W).
        size: Tile side length.

    Returns:
        Tensor of shape (N, C, size, size), ordered by leading index, then tile
        row, then tile column.
    """
    C, H, W = x.shape[-3:]
    pad_h = (size - H % size) % size
    pad_w = (size - W % size) % size
    if pad_h > 0 or pad_w > 0:
        # Pad order is (left, right, top, bottom) over the last two dimensions.
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))
    # Index from the end so the input is not required to be exactly 4-D.
    nh, nw = x.shape[-2] // size, x.shape[-1] // size
    return (
        # reshape (not view) so merged leading dims may be non-contiguous.
        x.reshape(-1, C, nh, size, nw, size)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(-1, C, size, size)
    )