Update pipeline.py
Browse files- pipeline.py +6 -5
pipeline.py
CHANGED
@@ -35,7 +35,7 @@ class MultiCaReClassifier():
|
|
35 |
|
36 |
# The outcome dataframe is created.
|
37 |
self.image_paths = get_image_files(self.image_folder)
|
38 |
-
self.data = pd.DataFrame(columns=[name for name in self.label_dict.keys() if os.path.isdir(os.path.join(
|
39 |
self.data['image_path'] = self.image_paths
|
40 |
self.predict_image_classes()
|
41 |
|
@@ -95,8 +95,8 @@ class MultiCaReClassifier():
|
|
95 |
|
96 |
'''Method used to identify the corresponding upper model of a given model.'''
|
97 |
|
98 |
-
colon_index =
|
99 |
-
dot_index =
|
100 |
index = max(colon_index, dot_index)
|
101 |
if index != -1:
|
102 |
return model_name[:index]
|
@@ -131,7 +131,8 @@ class MultiCaReClassifier():
|
|
131 |
# Models are ran.
|
132 |
if len(imgs) > 0:
|
133 |
device = 'cpu'
|
134 |
-
checkpoint_file = os.path.join(self.models_root, model_name.replace(':', '_'), 'model')
|
|
|
135 |
dls = ImageDataLoaders.from_path_func('', imgs, lambda x: '0', item_tfms=Resize((224,224), method='squish'))
|
136 |
learn = vision_learner(dls, resnet50, n_out=len(labels)).to_fp16()
|
137 |
learn.load(checkpoint_file, device=device)
|
@@ -203,4 +204,4 @@ class MultiCaReClassifier():
|
|
203 |
for column in column_list:
|
204 |
if column in label_list:
|
205 |
label = column
|
206 |
-
return label
|
|
|
35 |
|
36 |
# The outcome dataframe is created.
|
37 |
self.image_paths = get_image_files(self.image_folder)
|
38 |
+
self.data = pd.DataFrame(columns=[name for name in self.label_dict.keys() if os.path.isdir(os.path.join('models', name.replace(':', '_')))])
|
39 |
self.data['image_path'] = self.image_paths
|
40 |
self.predict_image_classes()
|
41 |
|
|
|
95 |
|
96 |
'''Method used to identify the corresponding upper model of a given model.'''
|
97 |
|
98 |
+
colon_index = self._search_last_match(model_name, ':')
|
99 |
+
dot_index = self._search_last_match(model_name, '.')
|
100 |
index = max(colon_index, dot_index)
|
101 |
if index != -1:
|
102 |
return model_name[:index]
|
|
|
131 |
# Models are ran.
|
132 |
if len(imgs) > 0:
|
133 |
device = 'cpu'
|
134 |
+
# checkpoint_file = os.path.join(self.models_root, model_name.replace(':', '_'), 'model')
|
135 |
+
checkpoint_file = os.path.join(model_name.replace(':', '_'), 'model')
|
136 |
dls = ImageDataLoaders.from_path_func('', imgs, lambda x: '0', item_tfms=Resize((224,224), method='squish'))
|
137 |
learn = vision_learner(dls, resnet50, n_out=len(labels)).to_fp16()
|
138 |
learn.load(checkpoint_file, device=device)
|
|
|
204 |
for column in column_list:
|
205 |
if column in label_list:
|
206 |
label = column
|
207 |
+
return label
|