Mask Generation
sam2
Tony Neel commited on
Commit
8ba5658
·
1 Parent(s): 41e99e7

Add custom handler for SAM2

Browse files
handler.py CHANGED
@@ -1,41 +1,47 @@
1
  from typing import Dict, List, Any
2
- from sam2.sam2_image_predictor import SAM2ImagePredictor
3
  import torch
4
- import numpy as np
5
- from PIL import Image
6
- import io
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
10
- # Initialize SAM2 predictor with small model
11
- self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
12
-
13
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
14
  """
 
15
  Args:
16
- data: Dictionary with "inputs" key containing image bytes
 
17
  Returns:
18
- Dictionary containing masks and scores
19
  """
20
- # Get input image
21
- if "inputs" not in data:
22
- raise ValueError("No inputs provided")
23
-
24
- # Convert input image bytes to PIL Image
25
- image = Image.open(io.BytesIO(data["inputs"]))
26
- image = np.array(image)
27
 
28
- # Process with SAM2
29
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
30
- self.predictor.set_image(image)
31
- masks, scores, _ = self.predictor.predict()
32
-
33
- # Convert masks to lists for JSON serialization
34
- if masks is not None:
35
- masks = [mask.tolist() for mask in masks]
36
- scores = scores.tolist() if scores is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- return {
39
- "masks": masks,
40
- "scores": scores
41
- }
 
1
  from typing import Dict, List, Any
2
+ from transformers import SamModel, SamProcessor
3
  import torch
 
 
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ self.model = SamModel.from_pretrained(path).to(self.device)
9
+ self.processor = SamProcessor.from_pretrained(path)
10
+
11
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
12
  """
13
+ Handle image segmentation requests
14
  Args:
15
+ data: Dictionary containing:
16
+ inputs: Raw image bytes
17
  Returns:
18
+ List of dictionaries containing segmentation masks
19
  """
20
+ # Get raw image bytes from the request
21
+ raw_image = data.pop("inputs", data)
 
 
 
 
 
22
 
23
+ # Process the image
24
+ inputs = self.processor(raw_image, return_tensors="pt").to(self.device)
25
+
26
+ # Generate image embeddings
27
+ image_embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
28
+
29
+ # Generate masks
30
+ outputs = self.model.generate(
31
+ image_embeddings=image_embeddings,
32
+ return_dict=True
33
+ )
34
+
35
+ # Process outputs
36
+ masks = outputs.pred_masks.squeeze().cpu().numpy()
37
+ scores = outputs.iou_scores.squeeze().cpu().numpy()
38
+
39
+ # Format response
40
+ results = []
41
+ for mask, score in zip(masks, scores):
42
+ results.append({
43
+ "mask": mask.tolist(), # Convert numpy array to list for JSON serialization
44
+ "score": float(score)
45
+ })
46
 
47
+ return results
 
 
 
images/20250121_gauge_0001.jpg ADDED
requirements.txt CHANGED
@@ -1 +1,5 @@
1
  sam2
 
 
 
 
 
1
  sam2
2
+ transformers
3
+ torch
4
+ pillow
5
+ numpy
test_endpoint.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from pathlib import Path
3
+ from PIL import Image
4
+ import io
5
+
6
+ def get_stored_token():
7
+ """Get the stored HuggingFace token"""
8
+ token_path = Path.home() / '.cache/huggingface/token'
9
+ if token_path.exists():
10
+ with open(token_path, 'r') as f:
11
+ return f.read().strip()
12
+ return None
13
+
14
+ # Update API URL to use the inference API endpoint
15
+ API_URL = "https://c3g262qlc7cizj5n.us-east4.gcp.endpoints.huggingface.cloud"
16
+ token = get_stored_token()
17
+
18
+ def query(image_path):
19
+ # Read image bytes directly
20
+ with open(image_path, "rb") as f:
21
+ image_bytes = f.read()
22
+
23
+ headers = {
24
+ "Authorization": f"Bearer {token}",
25
+ "Content-Type": "image/jpeg"
26
+ }
27
+
28
+ # Print some debug info
29
+ print(f"Sending file: {image_path}")
30
+ print(f"Content-Type: {headers['Content-Type']}")
31
+ print(f"Image size: {len(image_bytes)} bytes")
32
+
33
+ response = requests.post(
34
+ API_URL,
35
+ headers=headers,
36
+ data=image_bytes, # Send raw bytes
37
+ verify=True
38
+ )
39
+
40
+ # Add error handling
41
+ if response.status_code != 200:
42
+ print(f"Response headers: {response.headers}")
43
+ print(f"Request headers sent: {response.request.headers}")
44
+ return f"Error: {response.status_code}, {response.text}"
45
+ try:
46
+ return response.json()
47
+ except requests.exceptions.JSONDecodeError:
48
+ return f"Error decoding JSON. Raw response: {response.text}"
49
+
50
+ # Test with an image
51
+ if __name__ == "__main__":
52
+ # Option 1: Test with specific image
53
+ image_path = Path("images/20250121_gauge_0001.jpg")
54
+
55
+ # Option 2: Test with first image found in directory
56
+ # TRAIN_IMAGES_DIR = Path("images")
57
+ # image_path = next(TRAIN_IMAGES_DIR.glob('*.jpg'))
58
+
59
+ if not image_path.exists():
60
+ print(f"Error: Image not found at {image_path}")
61
+ exit(1)
62
+
63
+ print(f"Testing with image: {image_path}")
64
+ result = query(image_path)
65
+ print("\nAPI Response:")
66
+ print(result)