|
import os |
|
import glob |
|
import torch |
|
from os.path import join |
|
import numpy as np |
|
from froc_by_pranjal import file_to_bbox, calc_froc_from_dict, pretty_print_fps |
|
import sys |
|
from ensemble_boxes import * |
|
import json |
|
import pickle |
|
|
|
|
|
|
|
def get_file_id(x):
    """Return the second ``'_'``-separated field of file name *x*.

    e.g. ``'img_123_L.txt'`` -> ``'123'``. Raises IndexError if *x* contains no ``'_'``.
    """
    return x.split('_')[1]


def get_acr_cat(x, acr_cat=None):
    """Return the category stored for file id *x*, or ``'0'`` when it has none.

    Args:
        x: file id, as produced by ``get_file_id``.
        acr_cat: mapping file id -> category. When omitted, the module-level
            ``acr_cat`` (assigned in this script's ``__main__`` section) is used,
            so the original one-argument call keeps working.

    Raises:
        NameError: if no mapping is passed and no module-level ``acr_cat`` exists.
    """
    mapping = acr_cat if acr_cat is not None else globals().get('acr_cat')
    if mapping is None:
        raise NameError(
            "get_acr_cat needs an 'acr_cat' mapping: pass one explicitly "
            "or define the module-level 'acr_cat'"
        )
    return mapping[x] if x in mapping else '0'


# Category letter -> index N of the per-category 'AIIMS_C<N>' dataset/model.
cat_to_idx = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
|
|
|
|
|
def get_image_dict(dataset_paths, labels=('mal', 'ben'), allowed=None, USE_ACR=False, acr_cat=None, mp_dict=None):
    """Collect ground truth and per-model predictions for every image.

    Each entry of *dataset_paths* is a prediction-directory template with a
    ``{0}`` placeholder for the label. A directory holds one ``<key>.txt`` file
    per image; the ground truth is read from ``<parent>/gt/<key>.txt``, where
    ``<parent>`` is the directory containing the prediction directory. Only
    files present in *every* prediction directory (allowed or not) are used.

    Args:
        dataset_paths: prediction directory templates, one per model.
        labels: labels substituted into the templates. Ground truth is read
            only for ``'mal'``; any other label gets an empty ground truth.
        allowed: indices into *dataset_paths* to use; empty or None means all.
        USE_ACR: if True, a per-category ``AIIMS_C<N>`` model is used only for
            images whose category maps to N via ``cat_to_idx`` (and never for
            uncategorised ``'0'`` images), reading from ``/test_<cat>``
            instead of ``/test``.
        acr_cat: mapping file id -> category; when None the module-level
            mapping is used through ``get_acr_cat``.
        mp_dict: mapping substring -> callable. The first key contained in a
            dataset path post-processes that model's boxes. None means none.

    Returns:
        dict key -> {'gt': ground-truth boxes, 'preds': list of arrays with 5
        columns}; all arrays of one image are zero-padded to the same number
        of rows (at least 100).

    Raises:
        ValueError: if no allowed path applies to an image, or if the ground
            truths read for one image differ between dataset paths.
    """
    allowed = set(range(len(dataset_paths))) if not allowed else set(allowed)
    mp_dict = mp_dict or {}

    image_dict = dict()
    for label in labels:
        listings = [set(os.listdir(dset.format(label))) for dset in dataset_paths]
        # Sorted so image_dict is built in a reproducible order.
        images = sorted(set.intersection(*listings))

        for image in images:
            key = os.path.splitext(image)[0]

            if USE_ACR:
                file_id = get_file_id(image)
                if acr_cat is None:
                    acr = get_acr_cat(file_id)
                else:
                    acr = acr_cat[file_id] if file_id in acr_cat else '0'

            gts, preds = [], []
            for i, dset in enumerate(dataset_paths):
                if i not in allowed:
                    continue
                if USE_ACR and 'AIIMS_C' in dset:
                    # Per-category models only see images of their own category.
                    if acr == '0' or f'AIIMS_C{cat_to_idx[acr]}' not in dset:
                        continue
                    # NOTE(review): the '/test_<cat>' rewrite is applied to
                    # AIIMS_C models only; confirm other models should keep '/test'.
                    dset = dset.replace('/test', f'/test_{acr}')

                pred_dir = dset.format(label)
                pred_file = join(pred_dir, key + '.txt')
                gt_file = join(os.path.split(pred_dir)[0], 'gt', key + '.txt')
                gts.append(file_to_bbox(gt_file) if label == 'mal' else [])

                # Post-process with the first mp_dict entry found in the
                # (un-rewritten) dataset path, if any.
                boxes = file_to_bbox(pred_file)
                for pattern, func in mp_dict.items():
                    if pattern in dataset_paths[i]:
                        boxes = func(boxes)
                        break
                preds.append(boxes)

            if not gts:
                raise ValueError(f'no allowed dataset path provides predictions for {key!r}')
            if any(g != gts[0] for g in gts[1:]):
                raise ValueError(f'ground truth for {key!r} differs between dataset paths')

            arrays = [np.array(p) for p in preds]
            # An empty prediction list becomes a single all-zero row.
            arrays = [np.zeros((1, 5)) if a.shape == (0,) else a for a in arrays]
            # Zero-pad every model to a common row count (at least 100) so the
            # arrays of one image have equal length; models with more than 100
            # boxes no longer make np.zeros fail on a negative size.
            n_rows = max([100] + [len(a) for a in arrays])
            arrays = [np.vstack((a, np.zeros((n_rows - len(a), 5)))) for a in arrays]

            image_dict[key] = {'gt': gts[0], 'preds': arrays}

    return image_dict
|
|
|
|
|
def apply_merge(image_dict, METHOD='wbf', weights=None, conf_type=None, iou_thr=0.5, factor=5000):
    """Fuse each image's per-model predictions into one box list, in place.

    Box coordinates are divided by *factor* before fusion (ensemble_boxes
    works on coordinates normalised to [0, 1]) and multiplied back afterwards.

    Args:
        image_dict: mapping key -> {'preds': sequence of 2-D arrays whose rows
            are [score, box coordinates...], ...}, as built by get_image_dict.
        METHOD: 'wbf' selects weighted_boxes_fusion; any other value
            (including 'nms') selects non_maximum_weighted.
        weights: per-model weights; None gives every model weight 1, computed
            separately for each image.
        conf_type: forwarded to weighted_boxes_fusion only; ignored otherwise.
        iou_thr: IoU threshold passed to the fusion function.
        factor: coordinate normalisation constant (previously fixed at 5000).

    Returns:
        The same *image_dict*; every non-empty 'preds' is replaced by a list of
        [score, c0, c1, c2, c3] rows in the original coordinate scale.
    """
    fusion_func = weighted_boxes_fusion if METHOD == 'wbf' else non_maximum_weighted

    for entry in image_dict.values():
        preds = [np.asarray(p, dtype=float) for p in entry['preds']]
        if not preds:
            continue

        boxes_list = [p[:, 1:] / factor for p in preds]
        scores_list = [p[:, 0] for p in preds]
        labels_list = [[0.] * len(p) for p in preds]
        # Do not rebind `weights`: the number of models may differ per image.
        image_weights = [1] * len(preds) if weights is None else weights

        kwargs = {'weights': image_weights, 'iou_thr': iou_thr}
        if METHOD == 'wbf' and conf_type is not None:
            kwargs['conf_type'] = conf_type
        boxes, scores, _ = fusion_func(boxes_list, scores_list, labels_list, **kwargs)

        entry['preds'] = [[score, *(factor * box)] for score, box in zip(scores, boxes)]

    return image_dict
|
|
|
def manipulate_preds(preds):
    """Identity post-processing hook: hand *preds* back untouched.

    The very same object is returned (no copy). The ``__main__`` section
    registers it in ``mp_dict`` for the ``preds_frcnn_AIIMS_C3`` model.
    """
    untouched = preds
    return untouched
|
|
|
|
|
|
|
def manipulate_preds_4(preds):
    """No-op hook for the ``preds_frcnn_AIIMS_C4`` entry of ``mp_dict``.

    Returns its argument itself, unmodified.
    """
    result = preds
    return result
|
|
|
# Not referenced by the functions here; kept in case other code reads it.
tot = 0


def manipulate_preds_t1(preds, thr=0.6):
    """Keep only the predictions whose score is strictly greater than *thr*.

    Args:
        preds: iterable of rows whose first element is the confidence score.
        thr: score threshold; defaults to the previously hard-coded 0.6.

    Returns:
        A new list of the surviving rows, in their original order.
    """
    return [row for row in preds if row[0] > thr]


def manipulate_preds_t2(preds, thr=0.6):
    """Post-process ``AIIMS_T2`` predictions exactly like ``manipulate_preds_t1``."""
    return manipulate_preds_t1(preds, thr)
|
|
|
|
|
if __name__ == '__main__':
    # Usage: python <script> [START SKIP_LAST]
    # Uses models dataset_paths[START : len(dataset_paths) - SKIP_LAST]; with no
    # arguments every model is used. (Previously the parsed range was discarded
    # in favour of a hard-coded list while the arguments were still required.)
    USE_ACR = False

    dataset_paths = [
        'MammoDatasets/AIIMS_C1/test/{0}/preds_frcnn_AIIMS_C1',
        'MammoDatasets/AIIMS_C2/test/{0}/preds_frcnn_AIIMS_C2',
        'MammoDatasets/AIIMS_C3/test/{0}/preds_frcnn_AIIMS_C3',
        'MammoDatasets/AIIMS_C4/test/{0}/preds_frcnn_AIIMS_C4',
        'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_bilateral_BILATERAL',
        'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_frcnn_16',
    ]

    if len(sys.argv) not in (1, 3):
        sys.exit(f'usage: {sys.argv[0]} [START SKIP_LAST]')
    if len(sys.argv) == 3:
        st = int(sys.argv[1])
        end = len(dataset_paths) - int(sys.argv[2])
        allowed = list(range(st, end))
    else:
        allowed = list(range(len(dataset_paths)))

    OUT_FILE = 'contrast_frcnn.txt'
    if OUT_FILE is not None:
        fol = os.path.dirname(OUT_FILE)
        if fol:
            os.makedirs(fol, exist_ok=True)

    # The category mapping is only consulted when USE_ACR is on; it stays a
    # module-level global so get_acr_cat can fall back to it.
    acr_cat = None
    if USE_ACR:
        with open('aiims_categories.json', 'r') as f:
            acr_cat = json.load(f)

    print(allowed)

    # Dataset-path substring -> post-processing applied to that model's boxes.
    mp_dict = {
        'preds_frcnn_AIIMS_C3': manipulate_preds,
        'preds_frcnn_AIIMS_C4': manipulate_preds_4,
        'AIIMS_T2': manipulate_preds_t2,
        'AIIMS_T1': manipulate_preds_t1,
    }

    image_dict = get_image_dict(dataset_paths, allowed=allowed, USE_ACR=USE_ACR, acr_cat=acr_cat, mp_dict=mp_dict)

    # NOTE: apply_merge maps every METHOD other than 'wbf' to non_maximum_weighted.
    image_dict = apply_merge(image_dict, METHOD='nms')

    if OUT_FILE:
        pkl_file = os.path.splitext(OUT_FILE)[0] + '.pkl'
        with open(pkl_file, 'wb') as f:
            pickle.dump(image_dict, f)

    senses, fps = calc_froc_from_dict(image_dict, fps_req=[0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 1.], save_to=OUT_FILE)
    pretty_print_fps(senses, fps)
|
|