File size: 4,157 Bytes
970a7a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
from os.path import join
import glob
from sklearn.metrics import roc_auc_score, roc_curve
import sys

def file_to_score(file):
    """Return the highest prediction score listed in a prediction file.

    Every non-blank line is split on whitespace. When the first token of the
    first non-blank line is purely alphabetic, the score is read from the
    second column of each line; otherwise it is read from the first column.

    Args:
        file: Path of the prediction text file.

    Returns:
        The maximum score as a float. ``0.`` is returned when the file is
        empty, contains only blank lines, is missing, or cannot be parsed,
        so the result can always be used as a numeric prediction.
    """
    try:
        # The context manager guarantees the handle is closed on every path.
        with open(file, 'r') as handle:
            # Blank lines carry no score; dropping them keeps a trailing
            # newline from discarding the whole file via an IndexError.
            rows = [tokens for tokens in map(str.split, handle) if tokens]
        if not rows:
            return 0.
        column = 1 if rows[0][0].isalpha() else 0
        return max(float(tokens[column]) for tokens in rows)
    except FileNotFoundError:
        print(f'No Corresponding Box Found for file {file}, using 0. as preds')
        return 0.
    except Exception as e:
        # Deliberately best-effort: one unreadable or malformed file must not
        # abort the evaluation, so it is reported and scored as 0.
        print('Some Error', e)
        return 0.

# Create the image dict
def generate_image_dict(preds_folder_name='preds_42',
                        root_fol='/home/krithika_1/densebreeast_datasets/AIIMS_C1',
                        mal_path=None, ben_path=None, gt_path=None,
                        mal_img_path = None, ben_img_path = None
                        ):
    """Build a per-image mapping of ground-truth label and prediction score.

    The result has the structure
    ``{'image_name (without extension)': {'gt': label, 'preds': score}}``
    where ``label`` is ``1.`` for images found in the malignant image folder
    and ``0.`` for images found in the benign image folder.

    Entries are indexed from the image folders rather than the ground-truth
    folder, because the ground-truth files may differ slightly from the
    images.

    Args:
        preds_folder_name: Name of the prediction sub-folder used when
            ``mal_path`` / ``ben_path`` are not given.
        root_fol: Dataset root that every other path is joined onto.
        mal_path: Malignant prediction folder relative to ``root_fol``
            (default ``mal/<preds_folder_name>``).
        ben_path: Benign prediction folder relative to ``root_fol``
            (default ``ben/<preds_folder_name>``).
        gt_path: Accepted for backward compatibility; it does not influence
            the result because keys are derived from the image file names.
        mal_img_path: Malignant image folder relative to ``root_fol``
            (default ``mal/images``).
        ben_img_path: Benign image folder relative to ``root_fol``
            (default ``ben/images``).

    Returns:
        dict: The image dictionary described above.

    Raises:
        AssertionError: If a malignant prediction file has no matching
            ``.png`` in the malignant image folder.
    """
    mal_path = join(root_fol, mal_path) if mal_path else join(
        root_fol, 'mal', preds_folder_name)
    ben_path = join(root_fol, ben_path) if ben_path else join(
        root_fol, 'ben', preds_folder_name)
    mal_img_path = join(root_fol, mal_img_path) if mal_img_path else join(
        root_fol, 'mal', 'images')
    ben_img_path = join(root_fol, ben_img_path) if ben_img_path else join(
        root_fol, 'ben', 'images')

    image_dict = dict()

    # Every malignant image starts with a score of 0. so that images without
    # a prediction file still take part in the evaluation.
    for file in os.listdir(mal_img_path):
        if not file.endswith('.png'):
            continue
        image_dict[file[:-4]] = {'gt': 1., 'preds': 0.}

    for file in glob.glob(join(mal_path, '*.txt')):
        key = os.path.split(file)[-1][:-4]
        # Raised explicitly (instead of a bare `assert`) so the check also
        # runs under `python -O` and reports which file is at fault.
        if key not in image_dict:
            raise AssertionError(
                f'Prediction file {file} has no matching image in {mal_img_path}')
        score = file_to_score(file)
        # A non-numeric result (e.g. an empty list) cannot be scored, so it
        # falls back to the same 0. default as an image with no predictions.
        image_dict[key]['preds'] = score if isinstance(score, float) else 0.

    for file in os.listdir(ben_img_path):
        if not file.endswith('.png'):
            continue

        key = file[:-4]
        if key in image_dict:
            # The same image name exists in both classes; the malignant entry
            # registered above is kept and the benign duplicate is skipped.
            print(f'Duplicate image {key} found in both mal and ben folders; '
                  f'keeping the mal entry')
            continue
        score = file_to_score(join(ben_path, key + '.txt'))
        image_dict[key] = {
            'preds': score if isinstance(score, float) else 0.,
            'gt': 0.,
        }
    return image_dict

def get_auc_score_from_imdict(image_dict):
    """Compute the ROC AUC over all entries of an image dictionary.

    Args:
        image_dict: Mapping of image name to a dict holding the ground-truth
            label under ``'gt'`` and the prediction score under ``'preds'``.

    Returns:
        The value returned by ``roc_auc_score`` for the collected labels and
        scores.
    """
    labels = []
    scores = []
    # Walk the entries once so labels and scores stay aligned.
    for entry in image_dict.values():
        labels.append(entry['gt'])
        scores.append(entry['preds'])
    return roc_auc_score(labels, scores)

def get_accuracy_from_imdict(image_dict, thresh = 0.3):
    """Compute classification accuracy at a fixed score threshold.

    An entry is predicted positive when its score is greater than or equal to
    ``thresh`` and negative otherwise, so a score that sits exactly on the
    threshold is assigned to a class instead of always being counted wrong.

    Args:
        image_dict: Mapping of image name to a dict holding the ground-truth
            label under ``'gt'`` (``1.`` or ``0.``) and the prediction score
            under ``'preds'``.
        thresh: Decision threshold applied to the prediction score.

    Returns:
        The fraction of entries whose thresholded prediction matches the
        label, or ``0.0`` when ``image_dict`` is empty.
    """
    if not image_dict:
        # No samples: avoid dividing by zero.
        return 0.0
    correct = 0
    for entry in image_dict.values():
        label = entry['gt']
        predicted_positive = entry['preds'] >= thresh
        # Labels other than 0. / 1. never count as correct.
        if (label == 1. and predicted_positive) or (
                label == 0. and not predicted_positive):
            correct += 1
    return correct / len(image_dict)


def get_auc_score(preds_image_folder, root_fol, retAcc = False, acc_thresh = 0.3):
    """Evaluate the predictions stored in one prediction folder.

    Args:
        preds_image_folder: Name of the prediction folder, forwarded to
            ``generate_image_dict``.
        root_fol: Dataset root folder, forwarded to ``generate_image_dict``.
        retAcc: When true, the accuracy is returned alongside the AUC.
        acc_thresh: Threshold forwarded to ``get_accuracy_from_imdict``.

    Returns:
        The AUC score, or an ``(auc, accuracy)`` tuple when ``retAcc`` is
        true.
    """
    im_dict = generate_image_dict(preds_image_folder, root_fol = root_fol)
    auc = get_auc_score_from_imdict(im_dict)
    if not retAcc:
        return auc
    return auc, get_accuracy_from_imdict(im_dict, acc_thresh)

if __name__ == '__main__':
    # Dataset split that is evaluated when the file is run as a script.
    root_fol = '../bilateral_new/MammoDatasets/AIIMS_highres_reliable/test'

    # The seed selects the `preds_<seed>` folder; it falls back to '42' when
    # no command-line argument is supplied.
    if len(sys.argv) == 1:
        seed = '42'
    else:
        seed = sys.argv[1]

    auc_score = get_auc_score(f'preds_{seed}', root_fol)
    print(f'ROC AUC Score: {auc_score}')