from time import perf_counter
from typing import Any, List, Tuple, Union

import numpy as np

from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.models.types import PreprocessReturnMetadata


class BaseInference:
    """Minimal interface shared by inference implementations.

    ``infer`` wires together three hooks -- ``preprocess``, ``predict`` and
    ``postprocess`` -- each of which raises ``NotImplementedError`` here and
    is expected to be overridden by subclasses.
    """

    def infer(self, image: Any, **kwargs) -> Any:
        """Run the preprocess -> predict -> postprocess pipeline on ``image``.

        Args:
            image (Any): Input handed to ``preprocess``.
            **kwargs: Forwarded unchanged to all three pipeline hooks.

        Returns:
            Any: Whatever ``postprocess`` returns.
        """
        prepared, metadata = self.preprocess(image, **kwargs)
        raw_outputs = self.predict(prepared, **kwargs)
        return self.postprocess(raw_outputs, metadata, **kwargs)

    def preprocess(self, image: Any, **kwargs) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Turn ``image`` into model input plus metadata for ``postprocess``.

        Raises:
            NotImplementedError: This method must be implemented by a subclass.
        """
        raise NotImplementedError

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, ...]:
        """Run the model on the preprocessed input.

        Raises:
            NotImplementedError: This method must be implemented by a subclass.
        """
        raise NotImplementedError

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs
    ) -> Any:
        """Convert raw predictions into the final result.

        Args:
            predictions (Tuple[np.ndarray, ...]): Output of ``predict``.
            preprocess_return_metadata (PreprocessReturnMetadata): Second
                element of the tuple returned by ``preprocess``.
            **kwargs: The same keyword arguments that were given to ``infer``.

        Raises:
            NotImplementedError: This method must be implemented by a subclass.
        """
        raise NotImplementedError

    def infer_from_request(
        self,
        request: InferenceRequest,
    ) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Run inference for a request object.

        Args:
            request (InferenceRequest): The request object.

        Returns:
            Union[InferenceResponse, List[InferenceResponse]]: The response object(s).

        Raises:
            NotImplementedError: This method must be implemented by a subclass.
        """
        raise NotImplementedError

    def make_response(self, *args, **kwargs) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Build the response object(s) for an inference result.

        Raises:
            NotImplementedError: This method must be implemented by a subclass.
        """
        raise NotImplementedError


class Model(BaseInference):
    """Base class for concrete models, built on top of ``BaseInference``.

    Adds simple logging, a cache-clearing hook and a request-driven entry
    point (``infer_from_request``) on top of the pipeline hooks.

    Methods:
        log(m): Print the given message.
        clear_cache(): Hook for clearing caches; does nothing here.
        infer_from_request(request): Run inference for a request object.
        make_response(*args, **kwargs): Must be provided by subclasses.
    """

    def log(self, m):
        """Write a message to standard output.

        Args:
            m (str): The message to print.
        """
        print(m)

    def clear_cache(self):
        """Hook for releasing cached state; a no-op unless a subclass overrides it."""
        pass

    def infer_from_request(
        self, request: InferenceRequest
    ) -> Union[List[InferenceResponse], InferenceResponse]:
        """Run inference for ``request`` and return the response(s).

        The request's fields (``request.dict()``) are passed to ``infer`` as
        keyword arguments together with ``return_image_dims=False``. Every
        response then gets its ``time`` attribute set, and, when
        ``request.visualize_predictions`` is truthy, its ``visualization``
        attribute as well.

        Args:
            request (InferenceRequest): The request describing what to run.
                This method reads ``request.image``,
                ``request.visualize_predictions`` and ``request.dict()``.

        Returns:
            Union[List[InferenceResponse], InferenceResponse]: The full list of
            responses when ``request.image`` is a list. Otherwise the first
            response, or the (empty) result of ``infer`` itself if it produced
            no responses.

        Examples:
            >>> request = InferenceRequest(image=my_image, visualize_predictions=True)
            >>> response = model.infer_from_request(request)
            >>> print(response.time)  # Seconds elapsed since the call started
            0.125
            >>> print(response.visualization)  # Set because visualization was requested

        Notes:
            - ``time`` is measured from the start of this call up to the moment
              each response is stamped, so it covers the whole ``infer`` call.
            - ``draw_predictions`` is not defined in this class or in
              ``BaseInference`` as shown here; presumably subclasses that
              support visualization provide it -- TODO confirm.
        """
        started_at = perf_counter()
        results = self.infer(**request.dict(), return_image_dims=False)

        # Stamp each response individually; the clock is read once per response.
        for item in results:
            item.time = perf_counter() - started_at

        if request.visualize_predictions:
            for item in results:
                item.visualization = self.draw_predictions(request, item)

        # A non-list image means a single-image request: unwrap the only result.
        single_image = not isinstance(request.image, list)
        if single_image and len(results) > 0:
            return results[0]
        return results

    def make_response(self, *args, **kwargs) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Build an inference response from the given arguments.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            NotImplementedError: Always, with a message naming the concrete
                class; subclasses must override this method.
        """
        raise NotImplementedError(f"{self.__class__.__name__}.make_response")