# Hosting-page header captured together with this file (not part of the program):
#   Pranjal2041 · "Initial demo" · 970a7a2 · 5.65 kB
import os
import glob
import torch
from os.path import join
import numpy as np
from froc_by_pranjal import file_to_bbox, calc_froc_from_dict, pretty_print_fps
import sys
from ensemble_boxes import *
import json
import pickle
def get_file_id(x):
    """Return the second underscore-separated field of ``x``."""
    fields = x.split('_')
    return fields[1]


def get_acr_cat(x):
    """Return ``acr_cat[x]``, or ``'0'`` when ``x`` is not a key of ``acr_cat``.

    ``acr_cat`` is looked up as a module-level name at call time; in this
    file it is assigned in the ``__main__`` section.
    """
    if x in acr_cat:
        return acr_cat[x]
    return '0'


# Category letter -> 1-based index (used to build the 'AIIMS_C<n>' names).
cat_to_idx = dict(a=1, b=2, c=3, d=4)
def get_image_dict(dataset_paths, labels = ('mal', 'ben'), allowed = (), USE_ACR = False, acr_cat = None, mp_dict = None):
    """Collect ground truth and per-model predictions for every shared image.

    Args:
        dataset_paths: Format strings with a ``{0}`` placeholder for the
            label. ``path.format(label)`` must be a directory of prediction
            files; ground truth is read from the ``gt`` directory that sits
            next to it. Only files present in *all* of these directories
            (not just the ``allowed`` ones) are processed.
        labels: Labels substituted into each path. Ground-truth boxes are
            only read for ``'mal'``; any other label gets an empty list.
        allowed: Indices into ``dataset_paths`` whose predictions are used.
            An empty value means "use all of them".
        USE_ACR: When true, ``AIIMS_C`` datasets are filtered by the image's
            category (see the inline comments).
        acr_cat: Mapping from file id (``get_file_id(name)``) to a category
            key of ``cat_to_idx``. When ``None`` the module-level
            ``get_acr_cat`` helper is used instead.
        mp_dict: Mapping from a substring to a callable. For each dataset,
            the first key that occurs in its path selects a callable that is
            applied to the loaded predictions. ``None`` means no transforms.

    Returns:
        ``{key: {'gt': ..., 'preds': [array, ...]}}`` where ``key`` is the
        file name without its last four characters and each array has five
        columns, zero-padded so that all arrays of one image have the same
        number of rows (at least 100).
    """
    MIN_ROWS = 100  # every prediction array is padded to at least this many rows

    if not allowed:
        allowed = range(len(dataset_paths))
    allowed = frozenset(allowed)
    if mp_dict is None:
        mp_dict = {}

    image_dict = {}
    for label in labels:
        listings = [set(os.listdir(dset.format(label))) for dset in dataset_paths]
        # Sorted so the resulting dict has a reproducible key order.
        images = sorted(set.intersection(*listings))
        for image in images:
            if USE_ACR:
                file_id = get_file_id(image)
                if acr_cat is None:
                    acr = get_acr_cat(file_id)
                else:
                    acr = acr_cat[file_id] if file_id in acr_cat else '0'
            key = image[:-4]  # drop the 4-character extension (files are re-read as key + '.txt')
            gts = []
            preds = []
            for i, dset in enumerate(dataset_paths):
                if i not in allowed:
                    continue
                if USE_ACR:
                    if 'AIIMS_C' in dset:
                        # Unknown category: no category-specific model applies.
                        if acr == '0':
                            continue
                        # Keep only the model that belongs to this image's category.
                        if f'AIIMS_C{cat_to_idx[acr]}' not in dset:
                            continue
                        # Now choose dset to be the acr category one.
                        # NOTE(review): the original nesting of this statement could
                        # not be recovered; it may instead belong directly under
                        # `if USE_ACR:` (i.e. apply to every dataset) -- confirm.
                        dset = dset.replace('/test', f'/test_{acr}')
                pred_dir = dset.format(label)
                pred_file = join(pred_dir, key + '.txt')
                gt_file = join(os.path.split(pred_dir)[0], 'gt', key + '.txt')
                gts.append(file_to_bbox(gt_file) if label == 'mal' else [])

                boxes = file_to_bbox(pred_file)
                # Transforms are matched against the original (un-rewritten) path.
                for pattern, transform in mp_dict.items():
                    if pattern in dataset_paths[i]:
                        boxes = transform(boxes)
                        break
                preds.append(boxes)

            # Every dataset must agree on the ground truth for this image.
            reference_gt = gts[0]
            for other_gt in gts[1:]:
                assert other_gt == reference_gt, f'ground truth mismatch for {key}'

            # Bring all prediction sets to one common (rows, 5) shape.
            preds = [np.array(p) for p in preds]
            preds = [np.zeros((1, 5)) if p.shape == (0,) else p for p in preds]
            # Pad to the longest set as well, so more than MIN_ROWS predictions
            # no longer produce a negative padding size.
            n_rows = max(MIN_ROWS, max(len(p) for p in preds))
            preds = [np.vstack((p, np.zeros((n_rows - len(p), 5)))) for p in preds]

            image_dict[key] = {'gt': reference_gt, 'preds': preds}
    return image_dict
def apply_merge(image_dict, METHOD = 'wbf', weights = None, conf_type = None):
    """Fuse the per-model predictions of every image into one prediction list.

    ``image_dict`` is modified in place and also returned.

    Args:
        image_dict: ``{key: {'preds': [array, ...], ...}}``. Each array is
            indexed as ``[:, 0]`` for scores and ``[:, 1:]`` for box
            coordinates.
        METHOD: ``'wbf'`` selects ``weighted_boxes_fusion``; any other value
            (including ``'nms'``) selects ``non_maximum_weighted``.
        weights: Per-model weights passed to the fusion function. ``None``
            means a weight of 1 for each model of the image being fused.
        conf_type: Forwarded to the fusion function only when ``METHOD`` is
            ``'wbf'`` and the value is not ``None``.

    Returns:
        The same ``image_dict``; for every image with at least one
        prediction set, ``'preds'`` is replaced by a list of
        ``[score, c0, c1, c2, c3]`` rows.
    """
    # Coordinates are divided by FACTOR before fusion and multiplied back after.
    # Presumably this maps pixel coordinates into the [0, 1] range the fusion
    # functions expect -- TODO confirm 5000 bounds the image size.
    FACTOR = 5000
    fusion_func = weighted_boxes_fusion if METHOD == 'wbf' else non_maximum_weighted

    extra_kwargs = {}
    if METHOD == 'wbf' and conf_type is not None:
        extra_kwargs['conf_type'] = conf_type

    for entry in image_dict.values():
        # Kept as a list of arrays (not one stacked array) so prediction sets
        # of different lengths are accepted.
        preds = [np.asarray(p) for p in entry['preds']]
        if not preds:
            continue

        boxes_list = [p[:, 1:] / FACTOR for p in preds]
        scores_list = [p[:, 0] for p in preds]
        labels_list = [[0.] * len(p) for p in preds]  # single class
        # Resolved per image: the caller's `weights` is never overwritten, so
        # the default always matches this image's number of models.
        image_weights = [1] * len(preds) if weights is None else weights

        boxes, scores, _ = fusion_func(
            boxes_list, scores_list, labels_list,
            weights = image_weights, iou_thr = 0.5, **extra_kwargs
        )
        entry['preds'] = [
            [score, FACTOR * box[0], FACTOR * box[1], FACTOR * box[2], FACTOR * box[3]]
            for score, box in zip(scores, boxes)
        ]
    return image_dict
def manipulate_preds(preds):
    """Identity hook: hand back ``preds`` exactly as received."""
    return preds
def manipulate_preds_4(preds):
    """Identity hook: hand back ``preds`` exactly as received."""
    return preds
tot = 0  # NOTE(review): not read or updated anywhere in the visible code
def manipulate_preds_t1(preds, threshold = 0.6):
    """Drop low-confidence predictions.

    Args:
        preds: Iterable of predictions whose first element is compared
            against ``threshold``.
        threshold: Predictions are kept only when their first element is
            strictly greater than this value.

    Returns:
        A new list with the surviving predictions, in their original order.
    """
    return [pred for pred in preds if pred[0] > threshold]
def manipulate_preds_t2(preds):
    """Delegate to ``manipulate_preds_t1`` and return its result."""
    result = manipulate_preds_t1(preds)
    return result
if __name__ == '__main__':
    # Script entry point: load the predictions of every model, fuse them per
    # image, pickle the fused result and report FROC sensitivities.
    USE_ACR = False

    dataset_paths = [
        'MammoDatasets/AIIMS_C1/test/{0}/preds_frcnn_AIIMS_C1',
        'MammoDatasets/AIIMS_C2/test/{0}/preds_frcnn_AIIMS_C2',
        'MammoDatasets/AIIMS_C3/test/{0}/preds_frcnn_AIIMS_C3',
        'MammoDatasets/AIIMS_C4/test/{0}/preds_frcnn_AIIMS_C4',
        'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_bilateral_BILATERAL',
        'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_frcnn_16',
    ]

    # All datasets are used. The earlier sys.argv-derived range was always
    # overwritten by this value, so it was dropped; it only made the script
    # fail when the two arguments were missing.
    allowed = list(range(len(dataset_paths)))

    OUT_FILE = 'contrast_frcnn.txt'
    if OUT_FILE is not None:
        fol = os.path.split(OUT_FILE)[0]
        if fol != '':
            os.makedirs(fol, exist_ok=True)

    # Must stay a module-level name: get_acr_cat reads `acr_cat` as a global.
    with open('aiims_categories.json', 'r') as categories_file:
        acr_cat = json.load(categories_file)

    print(allowed)

    # Substring of a dataset path -> transform applied to its predictions.
    # The 'AIIMS_T*' keys match none of the paths listed above.
    mp_dict = {
        'preds_frcnn_AIIMS_C3': manipulate_preds,
        'preds_frcnn_AIIMS_C4': manipulate_preds_4,
        'AIIMS_T2': manipulate_preds_t2,
        'AIIMS_T1': manipulate_preds_t1,
    }

    image_dict = get_image_dict(dataset_paths, allowed = allowed, USE_ACR = USE_ACR, acr_cat = acr_cat, mp_dict = mp_dict)
    # Any METHOD other than 'wbf' makes apply_merge use non_maximum_weighted.
    image_dict = apply_merge(image_dict, METHOD = 'nms')  # or wbf

    if OUT_FILE:
        with open(OUT_FILE.replace('.txt', '.pkl'), 'wb') as pickle_file:
            pickle.dump(image_dict, pickle_file)

    senses, fps = calc_froc_from_dict(image_dict, fps_req = [0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 1.], save_to = OUT_FILE)
    pretty_print_fps(senses, fps)