Sartc commited on
Commit
3e855cc
·
verified ·
1 Parent(s): a0d8b60

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +6 -5
inference.py CHANGED
@@ -30,8 +30,9 @@ def save_output(mask, save_path):
30
  mask_image = Image.fromarray(mask[0])
31
  mask_image.save(save_path)
32
 
33
- weights_path = "unet_model.pth"
34
- model = load_model(weights_path, device)
35
- image_tensor = preprocess_image("DUTS-TE-Image/ILSVRC2012_test_00000003.jpg")
36
- mask = predict(model, image_tensor, device)
37
- save_output(mask, "predicted_mask.jpg")
 
 
30
  mask_image = Image.fromarray(mask[0])
31
  mask_image.save(save_path)
32
 
33
+ if __name__ == "__main__":
34
+ weights_path = "unet_model.pth"
35
+ model = load_model(weights_path, device)
36
+ image_tensor = preprocess_image("DUTS-TE-Image/ILSVRC2012_test_00000003.jpg")
37
+ mask = predict(model, image_tensor, device)
38
+ save_output(mask, "predicted_mask.jpg")