Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| # -*- coding: utf-8 -*- | |
| # Implemented Metrics for Cell detection | |
| # | |
| # This code is based on the following repository: https://github.com/TissueImageAnalytics/PanNuke-metrics | |
| # | |
| # Implemented metrics are: | |
| # | |
| # Instance Segmentation Metrics | |
| # Binary PQ | |
| # Multiclass PQ | |
| # Neoplastic PQ | |
| # Non-Neoplastic PQ | |
| # Inflammatory PQ | |
| # Connective PQ | |
| # Dead PQ | |
| # | |
| # Detection and Classification Metrics | |
| # Precision, Recall, F1 | |
| # | |
| # Other | |
| # dice1, dice2, aji, aji_plus | |
| # | |
| # Binary PQ (bPQ): Assumes all nuclei belong to same class and reports the average PQ across tissue types. | |
| # Multi-Class PQ (mPQ): Reports the average PQ across the classes and tissue types. | |
| # Neoplastic PQ: Reports the PQ for the neoplastic class on all tissues. | |
| # Non-Neoplastic PQ: Reports the PQ for the non-neoplastic class on all tissues. | |
| # Inflammatory PQ: Reports the PQ for the inflammatory class on all tissues. | |
| # Connective PQ: Reports the PQ for the connective class on all tissues. | |
| # Dead PQ: Reports the PQ for the dead class on all tissues. | |
| # | |
| # @ Fabian Hörst, [email protected] | |
| # Institute for Artificial Intelligence in Medicine, | |
| # University Medicine Essen | |
| from typing import List | |
| import numpy as np | |
| from scipy.optimize import linear_sum_assignment | |
def get_fast_pq(true, pred, match_iou=0.5):
    """Compute detection, segmentation and panoptic quality for one instance map pair.

    `match_iou` is the IoU threshold that decides whether a ground-truth
    instance and a predicted instance form a pair: they are paired when
    IoU > `match_iou`, and every pairing must be unique (one prediction to
    one ground-truth instance).

    If `match_iou` < 0.5, Munkres assignment (minimum weight matching in a
    bipartite graph) is used to find the maximal number of unique pairs.
    If `match_iou` >= 0.5, every pair with IoU > `match_iou` is already
    unique and the number of pairs is maximal, so no assignment is needed.

    Instance ids must be contiguous, i.e. [1, 2, 3, 4] and not
    [2, 3, 6, 10], because ids are used directly as list/matrix indices.
    Call `remap_label` beforehand; its `by_size` flag has no effect on the
    result.

    Args:
        true: ground-truth instance map, 0 marks background.
        pred: predicted instance map with the same shape as `true`,
            0 marks background.
        match_iou: IoU threshold for pairing, must be >= 0.

    Returns:
        [dq, sq, pq]: detection quality (F1), segmentation quality (mean
            IoU of the pairs) and their product.
        [paired_true, paired_pred, unpaired_true, unpaired_pred]: ids of
            the matched ground-truth / predicted instances and of the
            unmatched ones. The paired ids are numpy arrays when
            `match_iou` >= 0.5 and lists otherwise; the unpaired ids are
            always lists.
    """
    assert match_iou >= 0.0, "Cant' be negative"

    true = np.copy(true)
    pred = np.copy(pred)
    true_id_list = list(np.unique(true))
    pred_id_list = list(np.unique(pred))

    # Slot 0 of both id lists is reserved for the background. If a map has
    # no background pixel, add the id explicitly; otherwise the `[1:]`
    # slices below would drop the first real instance.
    if 0 not in true_id_list:
        true_id_list = [0] + true_id_list
    if 0 not in pred_id_list:
        pred_id_list = [0] + pred_id_list

    # Binary mask per instance, stored at the index equal to the instance id
    # (index 0 is a placeholder for the background).
    true_masks = [None]
    for t in true_id_list[1:]:
        true_masks.append(np.array(true == t, np.uint8))

    pred_masks = [None]
    for p in pred_id_list[1:]:
        pred_masks.append(np.array(pred == p, np.uint8))

    # Row i / column j hold the IoU of ground-truth id i + 1 and predicted
    # id j + 1; pairs that never overlap keep the prefilled 0.
    pairwise_iou = np.zeros(
        [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64
    )

    for true_id in true_id_list[1:]:
        t_mask = true_masks[true_id]
        # Predicted ids found under this ground-truth instance: only those
        # can have a non-zero IoU with it.
        pred_true_overlap = pred[t_mask > 0]
        for pred_id in list(np.unique(pred_true_overlap)):
            if pred_id == 0:  # overlap with background is ignored
                continue
            p_mask = pred_masks[pred_id]
            total = (t_mask + p_mask).sum()
            inter = (t_mask * p_mask).sum()
            # union = |t| + |p| - intersection
            iou = inter / (total - inter)
            pairwise_iou[true_id - 1, pred_id - 1] = iou

    if match_iou >= 0.5:
        # Above 0.5 the surviving entries are already a unique pairing.
        pairwise_iou[pairwise_iou <= match_iou] = 0.0
        paired_true, paired_pred = np.nonzero(pairwise_iou)
        paired_iou = pairwise_iou[paired_true, paired_pred]
        paired_true += 1  # matrix index is instance id - 1
        paired_pred += 1
    else:
        # Exhaustive maximal unique pairing via Munkres. The IoU matrix is
        # negated so that the minimum-cost assignment maximises IoU.
        paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
        paired_iou = pairwise_iou[paired_true, paired_pred]
        # Keep only assignments above the threshold; an assigned pair with
        # a lower IoU (e.g. 0.0, no intersection) is an FP or FN.
        keep = paired_iou > match_iou
        paired_true = list(paired_true[keep] + 1)
        paired_pred = list(paired_pred[keep] + 1)
        paired_iou = paired_iou[keep]

    # Everything that did not get a partner is a false negative (ground
    # truth) or false positive (prediction).
    paired_true_ids = set(paired_true)
    paired_pred_ids = set(paired_pred)
    unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true_ids]
    unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred_ids]

    tp = len(paired_true)
    fp = len(unpaired_pred)
    fn = len(unpaired_true)
    # DQ is the F1-score; the epsilon avoids a division by zero when both
    # maps are empty.
    dq = tp / (tp + 0.5 * fp + 0.5 * fn + 1.0e-6)
    # SQ is the mean IoU over the matched pairs.
    sq = paired_iou.sum() / (tp + 1.0e-6)

    return [dq, sq, dq * sq], [paired_true, paired_pred, unpaired_true, unpaired_pred]
| ##### | |
def remap_label(pred, by_size=False):
    """Rename all instance ids so that they are contiguous.

    The result uses ids [0, 1, 2, 3] instead of e.g. [0, 2, 4, 6]. The
    relative order of the instances (ascending original id) is preserved
    unless `by_size` is True, in which case larger instances receive
    smaller ids; instances of equal size keep their original relative order.

    Args:
        pred: array of instances where each instance is marked by a
            non-zero integer and 0 is background.
        by_size: if True, renumber so that bigger instances get smaller ids.

    Returns:
        A new int32 array of the same shape with contiguous ids, or `pred`
        itself (unchanged, not copied) when it contains no instance.
    """
    # One pass yields both the distinct ids and their pixel counts, so the
    # sizes do not need a separate scan per instance.
    inst_ids, inst_sizes = np.unique(pred, return_counts=True)
    foreground = inst_ids != 0
    inst_ids = inst_ids[foreground]
    inst_sizes = inst_sizes[foreground]

    if inst_ids.size == 0:
        return pred  # no label

    if by_size:
        # Stable sort on the negated sizes gives descending size while
        # keeping ascending-id order among equally sized instances.
        order = np.argsort(-inst_sizes, kind="stable")
        inst_ids = inst_ids[order]

    new_pred = np.zeros(pred.shape, np.int32)
    for new_id, inst_id in enumerate(inst_ids, start=1):
        new_pred[pred == inst_id] = new_id
    return new_pred
| #### | |
def binarize(x):
    """Collapse a multi-channel instance map into a single-channel one.

    Every distinct non-zero value in every channel becomes one instance in
    the output and receives a fresh id, counted upwards from 1 in the order
    the channels and their sorted values are visited. Where instances from
    different channels overlap, the one visited later overwrites the
    earlier one.

    :param x: array indexed as [row, column, channel]
        (for PanNuke 256*256*5)
    :return: int32 array of shape (x.shape[0], x.shape[1]) with 0 as
        background and ids 1..N for the instances
    """
    out = np.zeros([x.shape[0], x.shape[1]])
    count = 1
    for i in range(x.shape[2]):
        x_ch = x[:, :, i]
        for j in np.unique(x_ch).tolist():
            # 0 is background. It is skipped rather than removed from the
            # list so that a channel without any background pixel does not
            # raise.
            if j == 0:
                continue
            # Overwrite whatever an earlier instance left at these pixels.
            out[x_ch == j] = count
            count += 1
    return out.astype("int32")
def get_tissue_idx(tissue_indices, idx):
    """Return the position of the group in `tissue_indices` that holds `idx`.

    A group matches when it contains `idx` exactly once. If several groups
    match, the position of the last one is returned.

    Args:
        tissue_indices: sequence of groups; each group must provide a
            `count` method (e.g. a list).
        idx: the value to look for.

    Returns:
        Index into `tissue_indices` of the matching group.

    Raises:
        ValueError: if no group contains `idx` exactly once.
    """
    tiss_idx = None
    for i, group in enumerate(tissue_indices):
        if group.count(idx) == 1:
            tiss_idx = i
    if tiss_idx is None:
        raise ValueError(f"index {idx!r} was not found in any tissue group")
    return tiss_idx
def cell_detection_scores(
    paired_true, paired_pred, unpaired_true, unpaired_pred, w: List = [1, 1]
):
    """Compute detection F1, precision and recall from pairing results.

    Only the number of entries in each input is used: matched predictions
    count as true positives, unmatched predictions as false positives and
    unmatched ground-truth entries as false negatives. Arrays and plain
    sequences are both accepted.

    Args:
        paired_true: matched ground-truth entries (not used by the
            computation; kept for a signature parallel to
            `cell_type_detection_scores`).
        paired_pred: matched predicted entries.
        unpaired_true: ground-truth entries without a match.
        unpaired_pred: predicted entries without a match.
        w: weights applied to the false positives (`w[0]`) and the false
            negatives (`w[1]`) in the F1 denominator. The default list is
            only read, never modified.

    Returns:
        Tuple `(f1_d, prec_d, rec_d)`. A score whose denominator is zero
        (e.g. nothing predicted and nothing annotated) is reported as 0.0.
    """

    def _ratio(numerator, denominator):
        # Guard against empty inputs instead of raising ZeroDivisionError.
        return numerator / denominator if denominator else 0.0

    tp_d = len(paired_pred)
    fp_d = len(unpaired_pred)
    fn_d = len(unpaired_true)

    prec_d = _ratio(tp_d, tp_d + fp_d)
    rec_d = _ratio(tp_d, tp_d + fn_d)
    f1_d = _ratio(2 * tp_d, 2 * tp_d + w[0] * fp_d + w[1] * fn_d)

    return f1_d, prec_d, rec_d
def cell_type_detection_scores(
    paired_true,
    paired_pred,
    unpaired_true,
    unpaired_pred,
    type_id,
    w: List = [2, 2, 1, 1],
    exhaustive: bool = True,
):
    """Compute type-specific detection F1, precision and recall.

    Only matched pairs in which the ground truth or the prediction carries
    `type_id` are considered. Unmatched entries of that type are added as
    detection errors.

    Args:
        paired_true: type labels of the matched ground-truth entries.
        paired_pred: type labels of the matched predictions, aligned
            element-wise with `paired_true`.
        unpaired_true: type labels of unmatched ground-truth entries.
        unpaired_pred: type labels of unmatched predictions.
        type_id: the type to score.
        w: weights for, in order, type false positives, type false
            negatives, unmatched predictions and unmatched ground truth.
            The default list is only read, never modified.
        exhaustive: if False, matched pairs whose ground-truth label is -1
            are not counted as type false positives.

    Returns:
        Tuple `(f1_type, prec_type, rec_type)`. No guard is applied against
        zero denominators.
    """
    # Accept plain sequences as well as arrays; the boolean masks below
    # need array semantics.
    paired_true = np.asarray(paired_true)
    paired_pred = np.asarray(paired_pred)
    unpaired_true = np.asarray(unpaired_true)
    unpaired_pred = np.asarray(unpaired_pred)

    # Restrict to pairs that involve the requested type on either side.
    type_samples = (paired_true == type_id) | (paired_pred == type_id)
    paired_true = paired_true[type_samples]
    paired_pred = paired_pred[type_samples]

    true_is_type = paired_true == type_id
    pred_is_type = paired_pred == type_id

    tp_dt = (true_is_type & pred_is_type).sum()
    # After the selection above every pair involves `type_id`, so this
    # count is 0; it is kept so the formulas keep their original shape.
    tn_dt = (~true_is_type & ~pred_is_type).sum()
    fp_dt = (~true_is_type & pred_is_type).sum()
    fn_dt = (true_is_type & ~pred_is_type).sum()

    if not exhaustive:
        # Pairs with ground-truth label -1 are excluded from the type
        # false positives.
        ignore = (paired_true == -1).sum()
        fp_dt -= ignore

    fp_d = (unpaired_pred == type_id).sum()
    fn_d = (unpaired_true == type_id).sum()

    correct = tp_dt + tn_dt
    prec_type = correct / (correct + w[0] * fp_dt + w[2] * fp_d)
    rec_type = correct / (correct + w[1] * fn_dt + w[3] * fn_d)
    f1_type = (2 * correct) / (
        2 * correct + w[0] * fp_dt + w[1] * fn_dt + w[2] * fp_d + w[3] * fn_d
    )
    return f1_type, prec_type, rec_type
