mauro-nievoff commited on
Commit
5302093
·
verified ·
1 Parent(s): 3239c8b

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +6 -5
pipeline.py CHANGED
@@ -35,7 +35,7 @@ class MultiCaReClassifier():
35
 
36
  # The outcome dataframe is created.
37
  self.image_paths = get_image_files(self.image_folder)
38
- self.data = pd.DataFrame(columns=[name for name in self.label_dict.keys() if os.path.isdir(os.path.join(models_root, name.replace(':', '_')))])
39
  self.data['image_path'] = self.image_paths
40
  self.predict_image_classes()
41
 
@@ -95,8 +95,8 @@ class MultiCaReClassifier():
95
 
96
  '''Method used to identify the corresponding upper model of a given model.'''
97
 
98
- colon_index = search_last_match(model_name, ':')
99
- dot_index = search_last_match(model_name, '.')
100
  index = max(colon_index, dot_index)
101
  if index != -1:
102
  return model_name[:index]
@@ -131,7 +131,8 @@ class MultiCaReClassifier():
131
  # Models are ran.
132
  if len(imgs) > 0:
133
  device = 'cpu'
134
- checkpoint_file = os.path.join(self.models_root, model_name.replace(':', '_'), 'model')
 
135
  dls = ImageDataLoaders.from_path_func('', imgs, lambda x: '0', item_tfms=Resize((224,224), method='squish'))
136
  learn = vision_learner(dls, resnet50, n_out=len(labels)).to_fp16()
137
  learn.load(checkpoint_file, device=device)
@@ -203,4 +204,4 @@ class MultiCaReClassifier():
203
  for column in column_list:
204
  if column in label_list:
205
  label = column
206
- return label
 
35
 
36
  # The outcome dataframe is created.
37
  self.image_paths = get_image_files(self.image_folder)
38
+ self.data = pd.DataFrame(columns=[name for name in self.label_dict.keys() if os.path.isdir(os.path.join('models', name.replace(':', '_')))])
39
  self.data['image_path'] = self.image_paths
40
  self.predict_image_classes()
41
 
 
95
 
96
  '''Method used to identify the corresponding upper model of a given model.'''
97
 
98
+ colon_index = self._search_last_match(model_name, ':')
99
+ dot_index = self._search_last_match(model_name, '.')
100
  index = max(colon_index, dot_index)
101
  if index != -1:
102
  return model_name[:index]
 
131
  # Models are ran.
132
  if len(imgs) > 0:
133
  device = 'cpu'
134
+ # checkpoint_file = os.path.join(self.models_root, model_name.replace(':', '_'), 'model')
135
+ checkpoint_file = os.path.join(model_name.replace(':', '_'), 'model')
136
  dls = ImageDataLoaders.from_path_func('', imgs, lambda x: '0', item_tfms=Resize((224,224), method='squish'))
137
  learn = vision_learner(dls, resnet50, n_out=len(labels)).to_fp16()
138
  learn.load(checkpoint_file, device=device)
 
204
  for column in column_list:
205
  if column in label_list:
206
  label = column
207
+ return label