File size: 5,653 Bytes
970a7a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import os
import glob
import torch
from os.path import join
import numpy as np
from froc_by_pranjal import file_to_bbox, calc_froc_from_dict, pretty_print_fps
import sys
from ensemble_boxes import *
import json
import pickle
def get_file_id(x):
    """Return the second '_'-separated field of *x* (``x.split('_')[1]``)."""
    return x.split('_')[1]


def get_acr_cat(x):
    """Return the category stored for *x* in the global ``acr_cat``, or '0'.

    ``acr_cat`` is not defined at this point; it has to exist as a
    module-level name before this function is called (the script entry
    point loads it from a JSON file).
    """
    if x in acr_cat:
        return acr_cat[x]
    return '0'


# Category letter -> numeric suffix used in the 'AIIMS_C<n>' dataset names.
cat_to_idx = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
def get_image_dict(dataset_paths, labels=('mal', 'ben'), allowed=None, USE_ACR=False, acr_cat=None, mp_dict=None):
    """Gather ground truth and per-dataset predictions for every common image.

    For each label, only file names present in *every* formatted dataset
    directory are considered. The image key is the file name without its
    last four characters.

    Args:
        dataset_paths: Directory templates with a ``{0}`` placeholder that is
            filled with the label via ``str.format``. Predictions are read
            from ``<dir>/<key>.txt``; ground truth from the sibling
            ``gt/<key>.txt`` next to that directory.
        labels: Labels to iterate over. Ground truth is read only for
            ``'mal'``; any other label gets an empty ground-truth list.
        allowed: Indices into ``dataset_paths`` to use. ``None`` or an empty
            sequence selects every dataset.
        USE_ACR: When true, datasets whose path contains ``'AIIMS_C'`` are
            used only if they match the image's category, and are skipped
            entirely for category ``'0'``.
        acr_cat: Mapping from file id to category letter, consulted when
            ``USE_ACR`` is true. If ``None``, the module-level
            ``get_acr_cat`` helper is used instead.
        mp_dict: Mapping from a path substring to a callable. The first key
            found in a dataset's path template is applied to that dataset's
            predictions. ``None`` means no post-processing.

    Returns:
        dict mapping image key to ``{'gt': ..., 'preds': [...]}`` where
        ``preds`` holds one array per used dataset, zero-padded to 100 rows.
        Rows are assumed to have 5 values (score followed by 4 box
        coordinates) -- TODO confirm against ``file_to_bbox``.

    Raises:
        AssertionError: If the datasets disagree on an image's ground truth.
        ValueError: If a dataset yields more than 100 boxes for one image.
    """
    MAX_BOXES = 100  # every prediction array is zero-padded to this many rows

    image_dict = dict()
    if allowed is None or len(allowed) == 0:
        allowed = range(len(dataset_paths))
    allowed = set(allowed)
    # The default used to be iterated directly, which raised TypeError on None.
    if mp_dict is None:
        mp_dict = {}

    for label in labels:
        listings = [set(os.listdir(dset.format(label))) for dset in dataset_paths]
        images = list(set.intersection(*listings)) if listings else []
        for image in images:
            if USE_ACR:
                file_id = get_file_id(image)
                if acr_cat is None:
                    acr = get_acr_cat(file_id)
                else:
                    # Honour the mapping passed by the caller rather than a global.
                    acr = acr_cat[file_id] if file_id in acr_cat else '0'
            key = image[:-4]
            gts = []
            preds = []
            for i, dset in enumerate(dataset_paths):
                if i not in allowed:
                    continue
                if USE_ACR and 'AIIMS_C' in dset:
                    if acr == '0':
                        continue
                    if f'AIIMS_C{cat_to_idx[acr]}' not in dset:
                        continue
                    # Now choose dset to be the acr category one.
                    # NOTE(review): the source this was recovered from had lost
                    # its indentation; this rewrite is assumed to apply only to
                    # 'AIIMS_C' datasets -- confirm against the original file.
                    dset = dset.replace('/test', f'/test_{acr}')
                label_dir = dset.format(label)
                pred_file = join(label_dir, key + '.txt')
                gt_file = join(os.path.split(label_dir)[0], 'gt', key + '.txt')
                gts.append(file_to_bbox(gt_file) if label == 'mal' else [])

                boxes = file_to_bbox(pred_file)
                # Match against the original template, not the ACR-rewritten path.
                for marker, manipulate in mp_dict.items():
                    if marker in dataset_paths[i]:
                        boxes = manipulate(boxes)
                        break
                preds.append(boxes)

            # Ensure all gts are same
            gt = gts[0]
            for g in gts[1:]:
                assert g == gt
                gt = g

            # Pad every prediction set to a fixed (MAX_BOXES, 5) array.
            padded = []
            for p in preds:
                p = np.array(p)
                if p.shape == (0,):
                    p = np.array([[0., 0., 0., 0., 0.]])
                if len(p) > MAX_BOXES:
                    raise ValueError(
                        f'{key}: {len(p)} predicted boxes exceed the padding size {MAX_BOXES}'
                    )
                padded.append(np.vstack((p, np.zeros((MAX_BOXES - len(p), 5)))))

            image_dict[key] = dict()
            image_dict[key]['gt'] = gts[0]
            image_dict[key]['preds'] = padded
    return image_dict
def apply_merge(image_dict, METHOD='wbf', weights=None, conf_type=None):
    """Fuse each image's per-model predictions into one box list, in place.

    Args:
        image_dict: Mapping of image key to a dict whose ``'preds'`` entry is
            a sequence of 2-D arrays, one per model, with the score in
            column 0 and the box in the remaining columns.
        METHOD: ``'wbf'`` selects ``weighted_boxes_fusion``; any other value
            (including ``'nms'``) selects ``non_maximum_weighted``.
        weights: Per-model weights handed to the fusion function. When
            ``None``, every model of an image gets weight 1.
        conf_type: Forwarded to the fusion function only when ``METHOD`` is
            ``'wbf'`` and the value is not ``None``.

    Returns:
        The same ``image_dict``; each processed ``'preds'`` entry is replaced
        by a list of ``[score, c0, c1, c2, c3]`` rows. Images with no
        predictions are left untouched.
    """
    # Boxes are divided by FACTOR before fusion and multiplied back after.
    # Presumably this normalises pixel coordinates into [0, 1] for the fusion
    # functions -- assumes coordinates never exceed 5000; TODO confirm.
    FACTOR = 5000
    fusion_func = weighted_boxes_fusion if METHOD == 'wbf' else non_maximum_weighted
    extra_kwargs = {}
    if METHOD == 'wbf' and conf_type is not None:
        extra_kwargs['conf_type'] = conf_type

    for key in image_dict:
        # Per-model conversion avoids building one array from ragged inputs.
        preds = [np.asarray(p) for p in image_dict[key]['preds']]
        if not preds:
            continue
        boxes_list = [pred[:, 1:] / FACTOR for pred in preds]
        scores_list = [pred[:, 0] for pred in preds]
        labels = [[0.] * len(pred) for pred in preds]
        # Default weights are rebuilt per image. Rebinding `weights` froze the
        # first image's model count for all later images.
        image_weights = [1] * len(preds) if weights is None else weights
        boxes, scores, _ = fusion_func(
            boxes_list, scores_list, labels,
            weights=image_weights, iou_thr=0.5, **extra_kwargs
        )
        image_dict[key]['preds'] = [
            [score, FACTOR * box[0], FACTOR * box[1], FACTOR * box[2], FACTOR * box[3]]
            for score, box in zip(scores, boxes)
        ]
    return image_dict
def manipulate_preds(preds):
    """Return *preds* unchanged (no-op prediction post-processor)."""
    return preds
def manipulate_preds_4(preds):
    """Return *preds* unchanged (no-op prediction post-processor)."""
    return preds
# NOTE(review): `tot` is not read or updated anywhere in the visible code;
# possibly a leftover counter -- confirm before removing.
tot = 0
def manipulate_preds_t1(preds, threshold=0.6):
    """Keep only the predictions whose first element exceeds *threshold*.

    Args:
        preds: Iterable of indexable rows; element 0 of each row is compared
            against the threshold (the rest of the file treats it as the
            score).
        threshold: Strict lower bound on element 0. Defaults to 0.6, the
            cutoff that was previously hard-coded.

    Returns:
        A new list with the surviving rows in their original order; the
        input is not modified.
    """
    return [pred for pred in preds if pred[0] > threshold]
def manipulate_preds_t2(preds):
    """Apply exactly the same processing as ``manipulate_preds_t1``."""
    result = manipulate_preds_t1(preds)
    return result
if __name__ == '__main__':
    # Script entry point: merge the predictions of several detectors per
    # image, pickle the merged result, and report FROC sensitivities.
    # Kept as flat module-level code on purpose: `get_acr_cat` reads the
    # module global `acr_cat` assigned below.
    USE_ACR = False
    dataset_paths = [
        'MammoDatasets/AIIMS_C1/test/{0}/preds_frcnn_AIIMS_C1',
        'MammoDatasets/AIIMS_C2/test/{0}/preds_frcnn_AIIMS_C2',
        'MammoDatasets/AIIMS_C3/test/{0}/preds_frcnn_AIIMS_C3',
        'MammoDatasets/AIIMS_C4/test/{0}/preds_frcnn_AIIMS_C4',
        'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_bilateral_BILATERAL',
        'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_frcnn_16',
    ]

    # argv[1] is the first dataset index; argv[2] is how many trailing
    # datasets to drop. Both now default to 0 so the script no longer fails
    # with IndexError when they are omitted.
    st = int(sys.argv[1]) if len(sys.argv) > 1 else 0
    end = len(dataset_paths) - (int(sys.argv[2]) if len(sys.argv) > 2 else 0)
    allowed = list(range(st, end))
    # NOTE(review): the CLI-derived selection is overridden with all six
    # datasets, exactly as before. Remove this line to honour the arguments.
    allowed = [0, 1, 2, 3, 4, 5]

    OUT_FILE = 'contrast_frcnn.txt'
    if OUT_FILE is not None:
        fol = os.path.split(OUT_FILE)[0]
        if fol != '':
            os.makedirs(fol, exist_ok=True)

    with open('aiims_categories.json', 'r') as acr_file:
        acr_cat = json.load(acr_file)
    print(allowed)

    # Path substring -> post-processor applied to that dataset's predictions.
    # None of the 'AIIMS_T*' keys occurs in the dataset_paths listed above.
    mp_dict = {
        'preds_frcnn_AIIMS_C3': manipulate_preds,
        'preds_frcnn_AIIMS_C4': manipulate_preds_4,
        'AIIMS_T2': manipulate_preds_t2,
        'AIIMS_T1': manipulate_preds_t1,
    }

    image_dict = get_image_dict(dataset_paths, allowed=allowed, USE_ACR=USE_ACR, acr_cat=acr_cat, mp_dict=mp_dict)
    # Any METHOD other than 'wbf' makes apply_merge use non_maximum_weighted.
    image_dict = apply_merge(image_dict, METHOD='nms')  # or wbf

    if OUT_FILE:
        # splitext keeps the pickle path distinct from OUT_FILE even when
        # OUT_FILE does not end in '.txt'.
        pickle_path = os.path.splitext(OUT_FILE)[0] + '.pkl'
        with open(pickle_path, 'wb') as pickle_file:
            pickle.dump(image_dict, pickle_file)

    senses, fps = calc_froc_from_dict(image_dict, fps_req=[0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 1.], save_to=OUT_FILE)
    pretty_print_fps(senses, fps)
|