Update pipeline.py
Browse files- pipeline.py +1 -1
pipeline.py
CHANGED
@@ -135,7 +135,7 @@ class MultiCaReClassifier():
|
|
135 |
checkpoint_file = os.path.join(model_name.replace(':', '_'), 'model')
|
136 |
dls = ImageDataLoaders.from_path_func('', imgs, lambda x: '0', item_tfms=Resize((224,224), method='squish'))
|
137 |
learn = vision_learner(dls, resnet50, n_out=len(labels)).to_fp16()
|
138 |
-
learn.load(checkpoint_file, device=device)
|
139 |
test_dl = learn.dls.test_dl(imgs, device=device)
|
140 |
probs, _ = learn.get_preds(dl=test_dl)
|
141 |
self.data.loc[condition, model_name] = labels[probs.argmax(axis=1)]
|
|
|
135 |
checkpoint_file = os.path.join(model_name.replace(':', '_'), 'model')
|
136 |
dls = ImageDataLoaders.from_path_func('', imgs, lambda x: '0', item_tfms=Resize((224,224), method='squish'))
|
137 |
learn = vision_learner(dls, resnet50, n_out=len(labels)).to_fp16()
|
138 |
+
learn.load(checkpoint_file, device=device, weights_only=False)
|
139 |
test_dl = learn.dls.test_dl(imgs, device=device)
|
140 |
probs, _ = learn.get_preds(dl=test_dl)
|
141 |
self.data.loc[condition, model_name] = labels[probs.argmax(axis=1)]
|