"""Inference handler that serves SAM2 image segmentation masks."""

import base64
import contextlib
import io
from typing import Dict, List, Any, Union

import numpy as np
import torch
from huggingface_hub import InferenceEndpoint
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor


class EndpointHandler(InferenceEndpoint):
    """Callable handler that runs a SAM2 image predictor on a single image.

    The handler only relies on two methods of ``self.predictor``:
    ``set_image(image_array)`` and
    ``predict(point_coords=..., point_labels=...)`` returning a
    ``(masks, scores, extra)`` triple. For local testing without the model
    weights, any object providing those two methods can be assigned to
    ``self.predictor`` after construction.
    """

    def __init__(self, model_dir=None):
        """Load the pretrained SAM2 predictor.

        Args:
            model_dir (str, optional): Path to model directory. It is accepted
                for interface compatibility but not read; the predictor is
                always created from the "facebook/sam2-hiera-small"
                checkpoint. Defaults to None.
        """
        # NOTE(review): the InferenceEndpoint initializer is not invoked here
        # (same as before this change) -- confirm that subclassing
        # InferenceEndpoint is actually intended for this handler.
        self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")

    def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
        """Load image from binary or base64 data.

        Args:
            image_data: Encoded image file contents, either as raw bytes or
                as a base64 string of those bytes.

        Returns:
            The decoded PIL image, in whatever mode the file was stored in.

        Raises:
            ValueError: If the data cannot be base64-decoded or decoded as an
                image. The underlying exception is chained as the cause.
        """
        try:
            # Handle base64 encoded data
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)

            # Convert bytes to PIL Image
            image = Image.open(io.BytesIO(image_data))
            # Image.open is lazy: force the pixel data to be decoded now so
            # that truncated or corrupt files are reported from this method.
            image.load()
            return image
        except Exception as e:
            raise ValueError(f"Failed to load image: {str(e)}") from e

    def __call__(self, image_bytes: Union[bytes, str, Dict[str, Any]]) -> Dict[str, Any]:
        """Segment one image, optionally guided by point prompts.

        Args:
            image_bytes: Either the encoded image itself (raw bytes or a
                base64 string), or a dict with the encoded image under
                ``'image'`` and optional ``'point_coords'`` and
                ``'point_labels'`` entries that are forwarded to the
                predictor.

        Returns:
            On success, a dict with ``"masks"`` (each mask converted to nested
            lists), ``"scores"`` (a list, or None when the predictor returned
            no scores) and ``"status": "success"``. If the predictor returns
            no masks, ``{"error": "No masks generated", "status": "error"}``.

        Raises:
            KeyError: If a dict request has no ``'image'`` entry.
            ValueError: If the image data cannot be decoded.
        """
        # Get point prompts if provided in request
        if isinstance(image_bytes, dict):
            point_coords = image_bytes.get('point_coords')
            point_labels = image_bytes.get('point_labels')
            image_bytes = image_bytes['image']
        else:
            point_coords = None
            point_labels = None

        # Prompts typically arrive as plain lists (e.g. from JSON); hand the
        # predictor NumPy arrays instead.
        if point_coords is not None:
            point_coords = np.asarray(point_coords, dtype=np.float32)
        if point_labels is not None:
            point_labels = np.asarray(point_labels, dtype=np.int32)

        # Decode through the shared helper so base64 strings work here too.
        image = self._load_image(image_bytes)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_array = np.array(image)

        # Use bfloat16 autocast only when a CUDA device is present; otherwise
        # run the very same code path under a no-op context.
        if torch.cuda.is_available():
            autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            autocast_ctx = contextlib.nullcontext()

        # Run inference
        with torch.inference_mode(), autocast_ctx:
            self.predictor.set_image(image_array)
            masks, scores, _ = self.predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
            )

        # Format output
        if masks is None:
            return {"error": "No masks generated", "status": "error"}

        return {
            "masks": [mask.tolist() for mask in masks],
            "scores": scores.tolist() if scores is not None else None,
            "status": "success",
        }