tyriaa commited on
Commit
5ff1fa1
·
1 Parent(s): 66259e6

Initialisation 7

Browse files
Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -4,7 +4,31 @@ import os
4
  import shutil
5
  import numpy as np
6
  from PIL import Image
7
- from utils.predictor import Predictor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from utils.helpers import (
9
  blend_mask_with_image,
10
  save_mask_as_png,
 
4
  import shutil
5
  import numpy as np
6
  from PIL import Image
7
+ from sam2.build_sam import build_sam2
8
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
9
+
10
+ class Predictor:
11
+ def __init__(self, model_cfg, checkpoint, device):
12
+ self.device = device
13
+ self.model = build_sam2(model_cfg, checkpoint, device=device)
14
+ self.predictor = SAM2ImagePredictor(self.model)
15
+ self.image_set = False
16
+
17
+ def set_image(self, image):
18
+ """Set the image for SAM prediction."""
19
+ self.image = image
20
+ self.predictor.set_image(image)
21
+ self.image_set = True
22
+
23
+ def predict(self, point_coords, point_labels, multimask_output=False):
24
+ """Run SAM prediction."""
25
+ if not self.image_set:
26
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
27
+ return self.predictor.predict(
28
+ point_coords=point_coords,
29
+ point_labels=point_labels,
30
+ multimask_output=multimask_output
31
+ )
32
  from utils.helpers import (
33
  blend_mask_with_image,
34
  save_mask_as_png,