import base64
from io import BytesIO
from time import perf_counter
from typing import Any, List, Optional, Union

import numpy as np
import onnxruntime
import rasterio.features
import torch
from segment_anything import SamPredictor, sam_model_registry
from shapely.geometry import Polygon as ShapelyPolygon

from inference.core.entities.requests.inference import InferenceRequestImage
from inference.core.entities.requests.sam import (
    SamEmbeddingRequest,
    SamInferenceRequest,
    SamSegmentationRequest,
)
from inference.core.entities.responses.sam import (
    SamEmbeddingResponse,
    SamSegmentationResponse,
)
from inference.core.env import SAM_MAX_EMBEDDING_CACHE_SIZE, SAM_VERSION_ID
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.image_utils import load_image_rgb
from inference.core.utils.postprocess import masks2poly


class SegmentAnything(RoboflowCoreModel):
    """SegmentAnything class for handling segmentation tasks.

    Attributes:
        sam: The segmentation model.
        predictor: The predictor for the segmentation model.
        ort_session: ONNX runtime inference session.
        embedding_cache: Cache for embeddings.
        image_size_cache: Cache for image sizes.
        embedding_cache_keys: Keys for the embedding cache.
        low_res_logits_cache: Cache for low resolution logits.
        segmentation_cache_keys: Keys for the segmentation cache.
    """

    def __init__(self, *args, model_id: str = f"sam/{SAM_VERSION_ID}", **kwargs):
        """Initializes the SegmentAnything.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, model_id=model_id, **kwargs)
        self.sam = sam_model_registry[self.version_id](
            checkpoint=self.cache_file("encoder.pth")
        )
        self.sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
        self.predictor = SamPredictor(self.sam)
        self.ort_session = onnxruntime.InferenceSession(
            self.cache_file("decoder.onnx"),
            providers=[
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        self.embedding_cache = {}
        self.image_size_cache = {}
        self.embedding_cache_keys = []

        self.low_res_logits_cache = {}
        self.segmentation_cache_keys = []
        self.task_type = "unsupervised-segmentation"

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: List of file names.
        """
        return ["encoder.pth", "decoder.onnx"]

    def embed_image(self, image: Any, image_id: Optional[str] = None, **kwargs):
        """
        Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached,
        the cached result will be returned.

        Args:
            image (Any): The image to be embedded. The format should be compatible with the preproc_image method.
            image_id (Optional[str]): An identifier for the image. If provided, the embedding result will be cached
                                      with this ID. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image
                                               and the second element is the shape (height, width) of the processed image.

        Notes:
            - Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
            - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
              the oldest entries are removed.

        Example:
            >>> img_array = ... # some image array
            >>> embed_image(img_array, image_id="sample123")
            (array([...]), (224, 224))
        """
        if image_id and image_id in self.embedding_cache:
            return (
                self.embedding_cache[image_id],
                self.image_size_cache[image_id],
            )
        img_in = self.preproc_image(image)
        self.predictor.set_image(img_in)
        embedding = self.predictor.get_image_embedding().cpu().numpy()
        if image_id:
            self.embedding_cache[image_id] = embedding
            self.image_size_cache[image_id] = img_in.shape[:2]
            self.embedding_cache_keys.append(image_id)
            if len(self.embedding_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
                cache_key = self.embedding_cache_keys.pop(0)
                del self.embedding_cache[cache_key]
                del self.image_size_cache[cache_key]
        return (embedding, img_in.shape[:2])

    def infer_from_request(self, request: SamInferenceRequest):
        """Performs inference based on the request type.

        Args:
            request (SamInferenceRequest): The inference request.

        Returns:
            Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.
        """
        t1 = perf_counter()
        if isinstance(request, SamEmbeddingRequest):
            embedding, _ = self.embed_image(**request.dict())
            inference_time = perf_counter() - t1
            if request.format == "json":
                return SamEmbeddingResponse(
                    embeddings=embedding.tolist(), time=inference_time
                )
            elif request.format == "binary":
                binary_vector = BytesIO()
                np.save(binary_vector, embedding)
                binary_vector.seek(0)
                return SamEmbeddingResponse(
                    embeddings=binary_vector.getvalue(), time=inference_time
                )
        elif isinstance(request, SamSegmentationRequest):
            masks, low_res_masks = self.segment_image(**request.dict())
            if request.format == "json":
                masks = masks > self.predictor.model.mask_threshold
                masks = masks2poly(masks)
                low_res_masks = low_res_masks > self.predictor.model.mask_threshold
                low_res_masks = masks2poly(low_res_masks)
            elif request.format == "binary":
                binary_vector = BytesIO()
                np.savez_compressed(
                    binary_vector, masks=masks, low_res_masks=low_res_masks
                )
                binary_vector.seek(0)
                binary_data = binary_vector.getvalue()
                return binary_data
            else:
                raise ValueError(f"Invalid format {request.format}")

            response = SamSegmentationResponse(
                masks=[m.tolist() for m in masks],
                low_res_masks=[m.tolist() for m in low_res_masks],
                time=perf_counter() - t1,
            )
            return response

    def preproc_image(self, image: InferenceRequestImage):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The preprocessed image.
        """
        np_image = load_image_rgb(image)
        return np_image

    def segment_image(
        self,
        image: Any,
        embeddings: Optional[Union[np.ndarray, List[List[float]]]] = None,
        embeddings_format: Optional[str] = "json",
        has_mask_input: Optional[bool] = False,
        image_id: Optional[str] = None,
        mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
        mask_input_format: Optional[str] = "json",
        orig_im_size: Optional[List[int]] = None,
        point_coords: Optional[List[List[float]]] = [],
        point_labels: Optional[List[int]] = [],
        use_mask_input_cache: Optional[bool] = True,
        **kwargs,
    ):
        """
        Segments an image based on provided embeddings, points, masks, or cached results.
        If embeddings are not directly provided, the function can derive them from the input image or cache.

        Args:
            image (Any): The image to be segmented.
            embeddings (Optional[Union[np.ndarray, List[List[float]]]]): The embeddings of the image.
                Defaults to None, in which case the image is used to compute embeddings.
            embeddings_format (Optional[str]): Format of the provided embeddings; either 'json' or 'binary'. Defaults to 'json'.
            has_mask_input (Optional[bool]): Specifies whether mask input is provided. Defaults to False.
            image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks.
            mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input mask for the image.
            mask_input_format (Optional[str]): Format of the provided mask input; either 'json' or 'binary'. Defaults to 'json'.
            orig_im_size (Optional[List[int]]): Original size of the image when providing embeddings directly.
            point_coords (Optional[List[List[float]]]): Coordinates of points in the image. Defaults to an empty list.
            point_labels (Optional[List[int]]): Labels associated with the provided points. Defaults to an empty list.
            use_mask_input_cache (Optional[bool]): Flag to determine if cached mask input should be used. Defaults to True.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, np.ndarray]: A tuple where the first element is the segmentation masks of the image
                                          and the second element is the low resolution segmentation masks.

        Raises:
            ValueError: If necessary inputs are missing or inconsistent.

        Notes:
            - Embeddings, segmentations, and low-resolution logits can be cached to improve performance
              on repeated requests for the same image.
            - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
              the oldest entries are removed.
        """
        if not embeddings:
            if not image and not image_id:
                raise ValueError(
                    "Must provide either image, cached image_id, or embeddings"
                )
            elif image_id and not image and image_id not in self.embedding_cache:
                raise ValueError(
                    f"Image ID {image_id} not in embedding cache, must provide the image or embeddings"
                )
            embedding, original_image_size = self.embed_image(
                image=image, image_id=image_id
            )
        else:
            if not orig_im_size:
                raise ValueError(
                    "Must provide original image size if providing embeddings"
                )
            original_image_size = orig_im_size
            if embeddings_format == "json":
                embedding = np.array(embeddings)
            elif embeddings_format == "binary":
                embedding = np.load(BytesIO(embeddings))

        point_coords = point_coords
        point_coords.append([0, 0])
        point_coords = np.array(point_coords, dtype=np.float32)
        point_coords = np.expand_dims(point_coords, axis=0)
        point_coords = self.predictor.transform.apply_coords(
            point_coords,
            original_image_size,
        )

        point_labels = point_labels
        point_labels.append(-1)
        point_labels = np.array(point_labels, dtype=np.float32)
        point_labels = np.expand_dims(point_labels, axis=0)

        if has_mask_input:
            if (
                image_id
                and image_id in self.low_res_logits_cache
                and use_mask_input_cache
            ):
                mask_input = self.low_res_logits_cache[image_id]
            elif not mask_input and (
                not image_id or image_id not in self.low_res_logits_cache
            ):
                raise ValueError("Must provide either mask_input or cached image_id")
            else:
                if mask_input_format == "json":
                    polys = mask_input
                    mask_input = np.zeros((1, len(polys), 256, 256), dtype=np.uint8)
                    for i, poly in enumerate(polys):
                        poly = ShapelyPolygon(poly)
                        raster = rasterio.features.rasterize(
                            [poly], out_shape=(256, 256)
                        )
                        mask_input[0, i, :, :] = raster
                elif mask_input_format == "binary":
                    binary_data = base64.b64decode(mask_input)
                    mask_input = np.load(BytesIO(binary_data))
        else:
            mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)

        ort_inputs = {
            "image_embeddings": embedding.astype(np.float32),
            "point_coords": point_coords.astype(np.float32),
            "point_labels": point_labels,
            "mask_input": mask_input.astype(np.float32),
            "has_mask_input": (
                np.zeros(1, dtype=np.float32)
                if not has_mask_input
                else np.ones(1, dtype=np.float32)
            ),
            "orig_im_size": np.array(original_image_size, dtype=np.float32),
        }
        masks, _, low_res_logits = self.ort_session.run(None, ort_inputs)
        if image_id:
            self.low_res_logits_cache[image_id] = low_res_logits
            if image_id not in self.segmentation_cache_keys:
                self.segmentation_cache_keys.append(image_id)
            if len(self.segmentation_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
                cache_key = self.segmentation_cache_keys.pop(0)
                del self.low_res_logits_cache[cache_key]
        masks = masks[0]
        low_res_masks = low_res_logits[0]

        return masks, low_res_masks