Update inference.py
Browse files- inference.py +6 -5
inference.py
CHANGED
@@ -30,8 +30,9 @@ def save_output(mask, save_path):
|
|
30 |
mask_image = Image.fromarray(mask[0])
|
31 |
mask_image.save(save_path)
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
30 |
mask_image = Image.fromarray(mask[0])
|
31 |
mask_image.save(save_path)
|
32 |
|
33 |
+
if __name__ == "__main__":
|
34 |
+
weights_path = "unet_model.pth"
|
35 |
+
model = load_model(weights_path, device)
|
36 |
+
image_tensor = preprocess_image("DUTS-TE-Image/ILSVRC2012_test_00000003.jpg")
|
37 |
+
mask = predict(model, image_tensor, device)
|
38 |
+
save_output(mask, "predicted_mask.jpg")
|