File size: 10,811 Bytes
12a72cd
04cffe3
12a72cd
 
 
54e819a
12a72cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5302093
12a72cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5302093
 
12a72cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5302093
 
12a72cd
 
eed82ca
12a72cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5302093
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from fastai.vision.all import *
image_extensions.add('.webp')

class MultiCaReClassifier():

  def __init__(self, image_folder, models_root = 'MultiCaReClassifier/models', save_path = '', add_multiclass_columns = False):

    '''Class used to classify medical images considering their types (such as ultrasound or MRI), and the corresponding anatomical region and view (for radiology images only).
    image_folder (str): folder containing all the input images.
    models_root (str): folder containing the image classification models.
    save_path (str): path to save the inference table.
    add_multiclass_columns (bool): if True, multiclass columns will be added to the dataframe based on the multilabel column ('label_list').'''

    self.image_folder = os.path.join(image_folder, '')
    self.models_root = models_root
    self.save_path = save_path
    self.add_multiclass_columns = add_multiclass_columns

    # List of possible labels per model.
    self.label_dict = {
        "image_type:radiology~anatomical_region:axial_region": ["abdomen", "breast", "head", "neck", "pelvis", "thorax"],
        "image_type:radiology~anatomical_region:lower_limb": ["ankle", "foot", "hip", "knee", "lower_leg", "thigh"],
        "image_type:radiology~anatomical_view": ["axial", "frontal", "intravascular", "oblique", "occlusal", "panoramic", "periapical", "sagittal", "transabdominal", "transesophageal", "transthoracic", "transvaginal"],
        "image_type:endoscopy": ["airway_endoscopy", "arthroscopy", "ig_endoscopy", "other_endoscopy"],
        "image_type:electrography": ["eeg", "ekg", "emg"],
        "image_type:ophthalmic_imaging": ["autofluorescence", "b_scan", "fundus_photograph", "gonioscopy", "oct", "ophtalmic_angiography", "slit_lamp_photograph"],
        "image_type:radiology~anatomical_region:upper_limb": ["elbow", "forearm", "hand", "shoulder", "upper_arm", "wrist"],
        "image_type:radiology~anatomical_region": ["axial_region", "lower_limb", "upper_limb", "whole_body"],
        "image_type:radiology~main": ["ct", "mri", "pet", "scintigraphy", "spect", "tractography", "ultrasound", "x_ray"],
        "image_type:pathology": ["acid_fast", "alcian_blue", "congo_red", "fish", "giemsa", "gram", "h&e", "immunostaining", "masson_trichrome", "methenamine_silver", "methylene_blue", "papanicolaou", "pas", "van_gieson"],
        "image_type:radiology~anatomical_region:axial_region.thorax": ["cardiac_image", "other_thoracic_image"],
        "image_type:medical_photograph": ["oral_photograph", "other_medical_photograph", "skin_photograph"],
        "image_type": ["chart", "electrography", "endoscopy", "medical_photograph", "ophthalmic_imaging", "pathology", "radiology"]
        }

    # The outcome dataframe is created.
    self.image_paths = get_image_files(self.image_folder)
    self.data = pd.DataFrame(columns=[name for name in self.label_dict.keys() if os.path.isdir(os.path.join('models', name.replace(':', '_')))])
    self.data['image_path'] = self.image_paths
    self.predict_image_classes()

  ### Main Methods ###

  def predict_image_classes(self):

    '''Method used to get the predictions for each image.'''

    # Models are ran one level of hierarchy at a time.
    model_order = 1
    while True:
      order_count = 0
      for model_name in self.label_dict.keys():
        if len(re.split(r'[:.]', model_name)) == model_order: # The count of ':' and '.' in a model name is equivalent to its level of hierarchy.
          self._add_predictions(model_name)
          order_count += 1
      if order_count == 0: # If a level of hierarchy is empty, then the process is finished.
        break
      model_order += 1
    
    # Postprocessing is applied.
    self.apply_postprocessing()
    if self.save_path:
      self.data.to_csv(self.save_path, index=None)

  def apply_postprocessing(self):

    '''Method used to postprocess the predictions.'''

    # All predictions are merged in a single column as a list.
    columns_to_flatten = [c for c in self.data.columns if c.startswith('image_type')]
    self.data['label_list'] = self.data[columns_to_flatten].values.tolist()
    self.data['label_list'] = self.data['label_list'].apply(lambda x: [item for item in x if isinstance(item, (str, np.str_))])
    self.data.drop(columns_to_flatten, axis = 1, inplace = True)

    # Class typos are fixed and certain classes with low accuracy are merged.
    replacement_dict = {'transesophageal': 'ultrasound_view', 'transthoracic': 'ultrasound_view', 'transabdominal': 'ultrasound_view',
                        'transvaginal': 'ultrasound_view', 'ophtalmic_angiography': 'ophthalmic_angiography', 'ig_endoscopy': 'gi_endoscopy'}

    self.data['label_list'] = self.data['label_list'].apply(lambda x: [replacement_dict.get(item, item) for item in x])

    # Compound classes are added if there corresponding component classes are present.
    self.data['label_list'] = self.data['label_list'].apply(lambda x: self._add_compound_classes(x))

    # If multiclass columns are required, they are added.
    if self.add_multiclass_columns:
      self._generate_multiclass_columns()

    # Auxiliary classes are removed from the label list. This were added for the sake of the class structure of the taxonomy, but they do not add value to the list of predictions.
    auxiliary_classes = ['axial_region', 'cardiac_image', 'other_thoracic_image', 'intravascular', 'ultrasound_view']
    self.data['label_list'] = self.data['label_list'].apply(lambda x: [item for item in x if item not in auxiliary_classes])

  ### Auxiliary Methods ###

  def _identify_upper_model(self, model_name):

    '''Method used to identify the corresponding upper model of a given model.'''

    colon_index = self._search_last_match(model_name, ':')
    dot_index = self._search_last_match(model_name, '.')
    index = max(colon_index, dot_index)
    if index != -1:
      return model_name[:index]
    else:
      return None  

  def _search_last_match(self, string, character):

    '''Method used to find the last mention of a character in a string.'''

    if character in string:
      return string.rindex(character)
    else:
      return -1

  def _add_predictions(self, model_name):

    '''Method used to add all the predictions of a given model to the outcome dataframe.'''

    upper_model = self._identify_upper_model(model_name)

    # Models are used depending on the outcome of models from a higher hierarchy.
    if upper_model is not None:
      condition_class = model_name.split(':')[-1].split('~')[0].split('.')[-1]
      condition = self.data[model_name].isnull() & (self.data[upper_model] == condition_class)
    else:
      condition = self.data[model_name].isnull()
    imgs = self.data[condition].image_path.values

    labels = np.array(self.label_dict[model_name])

    # Models are ran.
    if len(imgs) > 0:
      device = 'cpu'
      # checkpoint_file = os.path.join(self.models_root, model_name.replace(':', '_'), 'model')
      checkpoint_file = os.path.join(model_name.replace(':', '_'), 'model')
      dls = ImageDataLoaders.from_path_func('', imgs, lambda x: '0', item_tfms=Resize((224,224), method='squish'))
      learn = vision_learner(dls, resnet50, n_out=len(labels)).to_fp16()
      learn.load(checkpoint_file, device=device, weights_only=False)
      test_dl = learn.dls.test_dl(imgs, device=device)
      probs, _ = learn.get_preds(dl=test_dl)
      self.data.loc[condition, model_name] = labels[probs.argmax(axis=1)]
    
  def _add_compound_classes(self, input_class_list):

    '''This method is used to add compound classes to the label list if the corresponding component classes are present.'''

    compound_class_dicts = [
        {'compound_class': 'echocardiogram', 'components': ['ultrasound', 'cardiac_image']},
        {'compound_class': 'ivus', 'components': ['ultrasound', 'intravascular']},
        {'compound_class': 'mammography', 'components': ['x_ray', 'breast']}
        ]

    for dct in compound_class_dicts:
      condition = True
      for cls in dct['components']:
        if cls not in input_class_list:
          condition = False
          break
      if condition:
        if dct['compound_class'] not in input_class_list:
          input_class_list.append(dct['compound_class'])

    return input_class_list

  def _generate_multiclass_columns(self):

    '''Method used to generate the multiclass columns based on the label list.'''

    image_types = ['chart', 'radiology', 'pathology', 'medical_photograph', 'ophthalmic_imaging', 'endoscopy', 'electrography']
    self.data['image_type'] = self.data['label_list'].apply(lambda x: self._get_column_label(x, image_types))

    image_subtypes = ['chart',
                      'ct', 'mri', 'x_ray', 'pet', 'spect', 'scintigraphy', 'ultrasound', 'tractography',
                      'acid_fast', 'alcian_blue', 'congo_red', 'fish', 'giemsa', 'gram', 'h&e', 'immunostaining', 'masson_trichrome', 'methenamine_silver', 'methylene_blue', 'papanicolaou', 'pas', 'van_gieson',
                      'skin_photograph', 'oral_photograph', 'other_medical_photograph',
                      'b_scan', 'autofluorescence', 'fundus_photograph', 'gonioscopy', 'oct', 'ophthalmic_angiography', 'slit_lamp_photograph',
                      'gi_endoscopy', 'airway_endoscopy', 'other_endoscopy', 'arthroscopy',
                      'eeg', 'emg', 'ekg']

    self.data['image_subtype'] = self.data['label_list'].apply(lambda x: self._get_column_label(x, image_subtypes))

    anatomical_regions = ['abdomen', 'breast', 'head', 'neck', 'pelvis', 'thorax',
                          'lower_limb', 'upper_limb', 'whole_body']

    self.data['radiology_region'] = self.data['label_list'].apply(lambda x: self._get_column_label(x, anatomical_regions))

    granular_anatomical_regions = ['abdomen', 'breast', 'head', 'neck', 'pelvis', 'thorax',
                                  'ankle', 'foot', 'hip', 'knee', 'lower_leg', 'thigh',
                                  'elbow', 'forearm', 'hand', 'shoulder', 'upper_arm', 'wrist',
                                  'whole_body']

    self.data['radiology_region_granular'] = self.data['label_list'].apply(lambda x: self._get_column_label(x, granular_anatomical_regions))

    anatomical_view = ['axial', 'frontal', 'sagittal', 'oblique',
                      'occlusal', 'panoramic', 'periapical', 'intravascular', 'ultrasound_view']

    self.data['radiology_view'] = self.data['label_list'].apply(lambda x: self._get_column_label(x, anatomical_view))

  def _get_column_label(self, column_list, label_list):

    '''Method used to get the label from a relevant list that is present in the predictions of a given image.'''

    label = ''
    for column in column_list:
      if column in label_list:
        label = column
    return label