kerzel's picture
fix in clustering to operate on image
e360db8
"""
Before we can identify damage sites, we need to look for suitable regions in the image.
Typically, damage sites appear as dark regions in the image. Instead of simple thresholding, we use
a clustering approach to identify regions that belong together and form damage site candidates.
"""
import numpy as np
import scipy.ndimage as ndi
from scipy.spatial import KDTree
from sklearn.cluster import DBSCAN
import logging
from PIL import Image # ADDED: Import PIL for image type checking/conversion
def _to_grayscale_array(image):
    """Return `image` as a 2-D NumPy array that can be thresholded directly.

    Args:
        image: PIL Image or NumPy array.

    Returns:
        np.ndarray: two-dimensional array of pixel intensities.

    Raises:
        ValueError: if `image` is neither a PIL Image nor a NumPy array, or if
            it is a NumPy array whose shape cannot be reduced to two dimensions.
    """
    if isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[2] in (3, 4):
            # Average the colour channels only. A fourth channel is treated as
            # alpha (opacity); including it in the mean would brighten every
            # opaque pixel and hide dark regions from the threshold.
            logging.info("get_centroids: Converted multi-channel NumPy array to grayscale NumPy array.")
            return np.mean(image[..., :3], axis=2).astype(image.dtype)
        if image.ndim == 3 and image.shape[2] == 1:
            # (H, W, 1) holds the same data as (H, W); drop the trailing axis so
            # the pixel coordinates passed to the clustering are two-dimensional.
            logging.info("get_centroids: Squeezed single-channel NumPy array to 2-D.")
            return image[..., 0]
        if image.ndim != 2:
            logging.error(f"get_centroids: Unsupported image array shape {image.shape}.")
            raise ValueError(
                "Unsupported image array shape {}. Expected (H, W), (H, W, 1), "
                "(H, W, 3) or (H, W, 4).".format(image.shape)
            )
        logging.info("get_centroids: Image is already a suitable NumPy array.")
        return image

    if isinstance(image, Image.Image):
        if image.mode == 'RGB':
            logging.info("get_centroids: Converted RGB PIL Image to grayscale NumPy array.")
            return np.array(image.convert('L'))
        if len(image.getbands()) == 1 and image.mode not in ('1', 'P'):
            # Single-band intensity modes (e.g. 'L', 'I', 'F') already are grayscale.
            logging.info("get_centroids: Converted PIL Image to NumPy array.")
            return np.array(image)
        # Everything else (RGBA, CMYK, LA, palette 'P', bilevel '1', ...) would
        # otherwise yield a multi-channel array, palette indices or booleans,
        # none of which can be compared against an intensity threshold.
        logging.info(f"get_centroids: Converted PIL Image of mode {image.mode} to grayscale NumPy array.")
        return np.array(image.convert('L'))

    logging.error("get_centroids: Unsupported image format received. Expected PIL Image or NumPy array.")
    raise ValueError("Unsupported image format. Expected PIL Image or NumPy array for thresholding.")


def get_centroids(image, image_threshold=20,
                  eps=1, min_samples=5, metric='euclidean',
                  min_size=20, fill_holes=False,
                  filter_close_centroids=False, filter_radius=50) -> list:
    """
    Determine centroids of clusters corresponding to potential damage sites.

    In a first step, a threshold is applied to the input image to identify areas of potential damage sites
    (pixels darker than the threshold). Using DBSCAN, these agglomerations of pixels are fitted into clusters.
    Then, the mean coordinates are determined from the pixels belonging to one cluster. If the number of pixels
    in a given cluster exceeds the threshold given by min_size, this cluster is added to the list of coordinates
    that is returned as the final list of potential damage sites.

    Sometimes, clusters may be found in very close proximity to each other; we can reject those to avoid
    classifying the same event multiple times (which may distort our statistics).

    DBSCAN documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html

    Args:
        image: Input SEM image (PIL Image or NumPy array). Colour images are reduced to grayscale first.
        image_threshold (int, optional): Pixels with a value strictly below this threshold are candidates for damage sites. Defaults to 20.
        eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
        min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
        metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
        min_size (int, optional): A cluster is only accepted if it contains strictly more than this number of pixels. Defaults to 20.
        fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
        filter_close_centroids (bool, optional): Filter cluster centroids within a given radius. Defaults to False
        filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50

    Returns:
        list: list of two-element lists, one per accepted damage site candidate. The first element is the mean
        index along array axis 0 (row) and the second the mean index along array axis 1 (column) of the
        cluster's pixels; this is the order referred to as "(x,y)" elsewhere in this function's history.

    Raises:
        ValueError: if the image is neither a PIL Image nor a NumPy array, or is an array that cannot be
            interpreted as a single 2-D image.
    """
    logging.info(f"get_centroids: Input image type: {type(image)}")
    image_array = _to_grayscale_array(image)

    # Binary mask: True where the pixel is darker than the threshold.
    cluster_candidates_mask = image_array < image_threshold

    # Clusters sometimes contain small holes (individual pixels above the
    # threshold inside a dark region) which can confuse the clustering later on.
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
    # N.B. the algorithm only works on binary data, hence it runs on the mask.
    if fill_holes:
        cluster_candidates_mask = ndi.binary_fill_holes(cluster_candidates_mask)

    # One row per candidate pixel: (index along axis 0, index along axis 1).
    cluster_candidates = np.transpose(np.nonzero(cluster_candidates_mask))
    if cluster_candidates.size == 0:
        logging.warning("No cluster candidates found after thresholding. Returning empty centroids list.")
        return []

    # DBSCAN assigns a cluster label to every candidate pixel; pixels that do
    # not belong to any cluster are labelled -1 ("noise").
    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
    dbscan.fit(cluster_candidates)
    labels = dbscan.labels_

    # np.unique returns the labels sorted, so clusters are visited in a
    # deterministic order (the proximity filter below depends on that order).
    cluster_ids = np.unique(labels)
    cluster_ids = cluster_ids[cluster_ids > -1]
    n_noise = int(np.count_nonzero(labels == -1))
    logging.debug('# clusters {}, #noise {}'.format(len(cluster_ids), n_noise))

    centroids = []
    for cluster_id in cluster_ids:
        cluster_points = cluster_candidates[labels == cluster_id, :]
        if len(cluster_points) > min_size:
            row_mean, col_mean = np.mean(cluster_points, axis=0)
            centroids.append([row_mean, col_mean])

    if filter_close_centroids and len(centroids) > 1:
        close_pairs = KDTree(centroids).query_pairs(filter_radius)
        # Of every pair closer than filter_radius, drop the centroid with the
        # higher index. Indices are collected first so the list is not modified
        # while it is being examined.
        # NOTE(review): in a chain A-B-C (A~B and B~C close, A and C far apart)
        # both B and C are dropped, although C's only close neighbour is gone.
        indices_to_remove = {max(pair) for pair in close_pairs}
        centroids = [centroid for idx, centroid in enumerate(centroids) if idx not in indices_to_remove]
        logging.info(f"Filtered {len(indices_to_remove)} close centroids. Remaining: {len(centroids)}")

    return centroids