jens commited on
Commit
b66242e
·
1 Parent(s): a0e14ae
Files changed (1) hide show
  1. inference.py +2 -2
inference.py CHANGED
@@ -121,8 +121,8 @@ class DepthPredictor:
121
 
122
  class SegmentPredictor:
123
  def __init__(self):
124
- MODEL_TYPE = "vit_b"
125
- checkpoint = "sam_vit_b_01ec64.pth"
126
  sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint)
127
  # Select device
128
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
121
 
122
  class SegmentPredictor:
123
  def __init__(self):
124
+ MODEL_TYPE = "vit_h"
125
+ checkpoint = "sam_vit_h_4b8939.pth"
126
  sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint)
127
  # Select device
128
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'