Mask Generation
sam2
Tony Neel commited on
Commit
41e99e7
·
1 Parent(s): 6c381d9

Add custom handler for SAM2

Browse files
Files changed (2) hide show
  1. handler.py +41 -0
  2. requirements.txt +1 -0
handler.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import io
7
+
8
+ class EndpointHandler:
9
+ def __init__(self, path=""):
10
+ # Initialize SAM2 predictor with small model
11
+ self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
12
+
13
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
14
+ """
15
+ Args:
16
+ data: Dictionary with "inputs" key containing image bytes
17
+ Returns:
18
+ Dictionary containing masks and scores
19
+ """
20
+ # Get input image
21
+ if "inputs" not in data:
22
+ raise ValueError("No inputs provided")
23
+
24
+ # Convert input image bytes to PIL Image
25
+ image = Image.open(io.BytesIO(data["inputs"]))
26
+ image = np.array(image)
27
+
28
+ # Process with SAM2
29
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
30
+ self.predictor.set_image(image)
31
+ masks, scores, _ = self.predictor.predict()
32
+
33
+ # Convert masks to lists for JSON serialization
34
+ if masks is not None:
35
+ masks = [mask.tolist() for mask in masks]
36
+ scores = scores.tolist() if scores is not None else None
37
+
38
+ return {
39
+ "masks": masks,
40
+ "scores": scores
41
+ }
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ sam2