Spaces:
Build error
Build error
| import os | |
| import glob | |
| import torch | |
| from os.path import join | |
| import numpy as np | |
| from froc_by_pranjal import file_to_bbox, calc_froc_from_dict, pretty_print_fps | |
| import sys | |
| from ensemble_boxes import * | |
| import json | |
| import pickle | |
def get_file_id(x):
    """Return the second '_'-separated field of the file name *x*.

    Raises IndexError when *x* contains no underscore.
    """
    return x.split('_')[1]


def get_acr_cat(x, categories=None):
    """Return the category recorded for file id *x*, or '0' when unknown.

    Args:
        x: File id, as produced by ``get_file_id``.
        categories: Optional mapping of file id -> category.  When omitted,
            the module-level ``acr_cat`` mapping is used; that global is only
            created by the ``__main__`` section of this file, so calling this
            without ``categories`` from an importing module raises NameError.
    """
    if categories is None:
        categories = acr_cat
    return categories[x] if x in categories else '0'


# Category letter -> numeric suffix used in the 'AIIMS_C<n>' dataset names.
cat_to_idx = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
def get_image_dict(dataset_paths, labels=None, allowed=None, USE_ACR=False, acr_cat=None, mp_dict=None):
    """Collect ground truth and per-model predictions for every shared image.

    For each label, the images considered are the file names present in
    *every* directory obtained by formatting each entry of ``dataset_paths``
    with that label (including entries not listed in ``allowed``).

    Args:
        dataset_paths: Directory templates containing a ``{0}`` placeholder
            that is filled with the label.
        labels: Labels to visit; defaults to ``['mal', 'ben']``.  Ground truth
            is read from disk only for ``'mal'``; every other label gets an
            empty ground-truth list.
        allowed: Indices into ``dataset_paths`` whose predictions are used.
            ``None`` or an empty sequence means "all of them".
        USE_ACR: When true, datasets whose path contains ``'AIIMS_C'`` are
            only used for images whose category matches that dataset.
        acr_cat: Optional mapping of file id -> category.  When ``None`` the
            module-level ``get_acr_cat`` helper (which reads a global) is used.
        mp_dict: Optional mapping of path substring -> callable applied to the
            predictions of the first matching dataset path.

    Returns:
        dict mapping the image name minus its last four characters to
        ``{'gt': <ground truth>, 'preds': [ndarray, ...]}``.  Each prediction
        array has 5 columns and is zero-padded to a common row count of at
        least 100.

    Raises:
        AssertionError: if the selected datasets disagree on the ground truth.
        IndexError: if no dataset was selected for an image.
    """
    if labels is None:
        labels = ['mal', 'ben']
    if not allowed:
        allowed = range(len(dataset_paths))
    allowed_idx = set(allowed)
    if mp_dict is None:
        mp_dict = {}

    image_dict = {}
    for label in labels:
        listings = [set(os.listdir(dset.format(label))) for dset in dataset_paths]
        # Sorted so the resulting dict (and anything pickled from it) has a
        # reproducible order instead of depending on set iteration order.
        images = sorted(set.intersection(*listings))
        for image in images:
            acr = None
            if USE_ACR:
                file_id = get_file_id(image)
                if acr_cat is not None:
                    acr = acr_cat[file_id] if file_id in acr_cat else '0'
                else:
                    acr = get_acr_cat(file_id)
            # Drops the last four characters, i.e. assumes a 3-letter extension.
            key = image[:-4]
            gts = []
            preds = []
            for i, dset in enumerate(dataset_paths):
                if i not in allowed_idx:
                    continue
                if USE_ACR and 'AIIMS_C' in dset:
                    if acr == '0':
                        continue
                    if f'AIIMS_C{cat_to_idx[acr]}' not in dset:
                        continue
                    # Switch to the category-specific test split.
                    # NOTE(review): indentation was lost in the source; this
                    # replace is assumed to apply only to 'AIIMS_C' datasets.
                    dset = dset.replace('/test', f'/test_{acr}')
                label_dir = dset.format(label)
                pred_file = join(label_dir, key + '.txt')
                gt_file = join(os.path.split(label_dir)[0], 'gt', key + '.txt')
                gts.append(file_to_bbox(gt_file) if label == 'mal' else [])

                boxes = file_to_bbox(pred_file)
                # The first mp_dict key found in the *original* path wins.
                for pattern, transform in mp_dict.items():
                    if pattern in dataset_paths[i]:
                        boxes = transform(boxes)
                        break
                preds.append(boxes)

            # Every selected dataset must describe the same ground truth.
            reference_gt = gts[0]
            for other_gt in gts[1:]:
                assert other_gt == reference_gt, f'ground truth mismatch for {key}'

            # Turn each prediction list into an (n, 5) array; an empty list
            # becomes a single all-zero row.
            arrays = [np.array(p) for p in preds]
            arrays = [np.zeros((1, 5)) if a.shape == (0,) else a for a in arrays]
            # Pad to a common height.  100 is the historical minimum; growing
            # to the longest array avoids a negative pad size when a model
            # returns more than 100 boxes.
            n_rows = max(100, max(len(a) for a in arrays))
            arrays = [np.vstack((a, np.zeros((n_rows - len(a), 5)))) for a in arrays]

            image_dict[key] = {'gt': reference_gt, 'preds': arrays}
    return image_dict
def apply_merge(image_dict, METHOD='wbf', weights=None, conf_type=None):
    """Fuse the per-model predictions of every image into one box list.

    ``image_dict`` is modified in place (each entry's ``'preds'`` is replaced)
    and also returned.

    Args:
        image_dict: Mapping of image key -> dict with a ``'preds'`` entry
            holding one equally-shaped array per model.  Column 0 is used as
            the score and the remaining columns as the box.
        METHOD: ``'wbf'`` selects ``weighted_boxes_fusion``; any other value
            (including ``'nms'``) selects ``non_maximum_weighted``.
        weights: Optional per-model weights.  When ``None``, every model of
            the current image gets weight 1.
        conf_type: Forwarded to the fusion function only when ``METHOD`` is
            ``'wbf'`` and the value is not ``None``.

    Returns:
        The same ``image_dict``; each fused ``'preds'`` is a list of
        ``[score, b0, b1, b2, b3]`` rows.  Entries with no predictions are
        left untouched.
    """
    # Box coordinates are divided by FACTOR before fusion and multiplied back
    # afterwards.  NOTE(review): presumably this maps pixel coordinates into
    # the [0, 1] range the fusion functions expect - confirm that no
    # coordinate exceeds 5000.
    FACTOR = 5000
    use_wbf = METHOD == 'wbf'
    fusion_func = weighted_boxes_fusion if use_wbf else non_maximum_weighted

    for entry in image_dict.values():
        preds = np.array(entry['preds'])
        if len(preds) == 0:
            continue

        boxes_list = [pred[:, 1:] / FACTOR for pred in preds]
        scores_list = [pred[:, 0] for pred in preds]
        # Single-class problem: every box gets label 0.
        labels_list = [[0.] * len(pred) for pred in preds]
        # Built per image and never written back to ``weights``, so a default
        # sized for one image is not reused for an image with a different
        # number of models.
        model_weights = [1] * len(preds) if weights is None else weights

        fusion_kwargs = {'weights': model_weights, 'iou_thr': 0.5}
        if use_wbf and conf_type is not None:
            fusion_kwargs['conf_type'] = conf_type
        boxes, scores, _ = fusion_func(boxes_list, scores_list, labels_list, **fusion_kwargs)

        entry['preds'] = [
            [score, FACTOR * box[0], FACTOR * box[1], FACTOR * box[2], FACTOR * box[3]]
            for score, box in zip(scores, boxes)
        ]
    return image_dict
def manipulate_preds(preds):
    """Return *preds* unchanged (identity hook used as an ``mp_dict`` value)."""
    return preds
def manipulate_preds_4(preds):
    """Return *preds* unchanged (identity hook used as an ``mp_dict`` value)."""
    return preds
# NOTE(review): not referenced anywhere in the visible part of the file.
tot = 0


def manipulate_preds_t1(preds, threshold=0.6):
    """Keep only the predictions whose first element exceeds *threshold*.

    Args:
        preds: Iterable of indexable rows; element 0 of each row is compared.
        threshold: Strict lower bound on element 0 (default 0.6).

    Returns:
        A new list with the surviving rows, in their original order.
    """
    return [row for row in preds if row[0] > threshold]


def manipulate_preds_t2(preds):
    """Apply the same filtering as ``manipulate_preds_t1`` with its default."""
    return manipulate_preds_t1(preds)
if __name__ == '__main__':
    # Script entry point: gather predictions from every dataset below, fuse
    # them per image, pickle the fused dictionary and report FROC numbers.
    #
    # Everything stays at module level (not inside a main() function) because
    # get_acr_cat reads ``acr_cat`` as a module global.
    USE_ACR = False
    dataset_paths = [
        'MammoDatasets/AIIMS_C1/test/{0}/preds_frcnn_AIIMS_C1',
        'MammoDatasets/AIIMS_C2/test/{0}/preds_frcnn_AIIMS_C2',
        'MammoDatasets/AIIMS_C3/test/{0}/preds_frcnn_AIIMS_C3',
        'MammoDatasets/AIIMS_C4/test/{0}/preds_frcnn_AIIMS_C4',
        'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_bilateral_BILATERAL',
        'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_frcnn_16',
    ]
    # All datasets are used.  An earlier revision derived a window from
    # sys.argv[1] / sys.argv[2] but overwrote it straight away with every
    # index, so the parsing only made the script fail when run without
    # arguments; it has been removed.
    allowed = list(range(len(dataset_paths)))

    OUT_FILE = 'contrast_frcnn.txt'
    if OUT_FILE:
        fol = os.path.split(OUT_FILE)[0]
        if fol != '':
            os.makedirs(fol, exist_ok=True)

    with open('aiims_categories.json', 'r') as categories_file:
        acr_cat = json.load(categories_file)
    print(allowed)

    # Path substring -> post-processing hook applied to that dataset's boxes.
    mp_dict = {
        'preds_frcnn_AIIMS_C3': manipulate_preds,
        'preds_frcnn_AIIMS_C4': manipulate_preds_4,
        'AIIMS_T2': manipulate_preds_t2,
        'AIIMS_T1': manipulate_preds_t1,
    }

    image_dict = get_image_dict(dataset_paths, allowed=allowed, USE_ACR=USE_ACR, acr_cat=acr_cat, mp_dict=mp_dict)
    # Any METHOD other than 'wbf' makes apply_merge use non_maximum_weighted.
    image_dict = apply_merge(image_dict, METHOD='nms')  # or wbf

    if OUT_FILE:
        with open(OUT_FILE.replace('.txt', '.pkl'), 'wb') as pickle_file:
            pickle.dump(image_dict, pickle_file)

    # NOTE(review): indentation was lost in the source; the FROC computation
    # is assumed to run regardless of OUT_FILE.
    senses, fps = calc_froc_from_dict(image_dict, fps_req=[0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 1.], save_to=OUT_FILE)
    pretty_print_fps(senses, fps)