"Filter definitions, with pre-processing, post-processing and compilation methods."

import numpy as np
import torch
from torch import nn
from common import AVAILABLE_FILTERS, INPUT_SHAPE

from concrete.numpy.compilation.compiler import Compiler
from concrete.ml.common.utils import generate_proxy_function
from concrete.ml.torch.numpy_module import NumpyModule


class TorchIdentity(nn.Module):
    """Torch identity model."""

    def forward(self, x):
        """Identity forward pass.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            x (torch.Tensor): The input image.
        """
        return x


class TorchInverted(nn.Module):
    """Torch inverted model."""

    def forward(self, x):
        """Forward pass for inverting an image's colors.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The (color) inverted image.
        """
        return 255 - x


class TorchRotate(nn.Module):
    """Torch rotated model."""

    def forward(self, x):
        """Forward pass for rotating an image.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The rotated image.
        """
        return x.transpose(0, 1)


class TorchConv(nn.Module):
    """Torch model with a single convolution operator."""

    def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1, threshold=None):
        """Initialize the filter.

        Args:
            kernel (np.ndarray): The convolution kernel to consider.
            n_in_channels (int): The number of input channels to consider. Default to 3.
            n_out_channels (int): The number of output channels to consider. Default to 3.
            groups (int): The number of blocked connections from input to output channels,
                as used by the conv2d operator. Default to 1.
            threshold (Optional[int]): A threshold value to subtract from the convolution's
                output. Default to None.
        """
        super().__init__()
        self.kernel = torch.tensor(kernel, dtype=torch.int64)
        self.n_out_channels = n_out_channels
        self.n_in_channels = n_in_channels
        self.groups = groups
        self.threshold = threshold

    def forward(self, x):
        """Forward pass with a single convolution using a 1D or 2D kernel.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The filtered image.
        """
        # Define the convolution parameters
        stride = 1
        kernel_shape = self.kernel.shape

        # Ensure the kernel has a proper shape
        # If the kernel has a 1D shape, a (1, 1) kernel is used for each in_channels
        if len(kernel_shape) == 1:
            kernel = self.kernel.reshape(
                self.n_out_channels,
                self.n_in_channels // self.groups,
                1,
                1,
            )

        # Else, if the kernel has a 2D shape, a single (Kw, Kh) kernel is used on all in_channels
        elif len(kernel_shape) == 2:
            kernel = self.kernel.expand(
                self.n_out_channels,
                self.n_in_channels // self.groups,
                kernel_shape[0],
                kernel_shape[1],
            )

        else:
            raise ValueError(
                "Wrong kernel shape, only 1D or 2D kernels are accepted. Got kernel of shape "
                f"{kernel_shape}"
            )
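
        # As an example of the two cases above: the black and white filter's [299, 587, 114]
        # kernel is reshaped into a (1, 3, 1, 1) tensor, computing a single weighted sum over
        # the RGB channels, while the blur filter's 3x3 kernel of ones is expanded into a
        # (3, 1, 3, 3) tensor which, combined with groups=3, filters each channel independently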

        # Reshape the image. This is done because Torch convolutions and Numpy arrays (for PIL 
        # display) don't follow the same shape conventions. More precisely, x is of shape 
        # (Width, Height, Channels) while the conv2d operator requires an input of shape 
        # (Batch, Channels, Height, Width)
        x = x.transpose(2, 0).unsqueeze(0)
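        # x is now of shape (Batch, Channels, Height, Width), i.e. (1, C, H, W)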

        # Apply the convolution
        x = nn.functional.conv2d(x, kernel, stride=stride, groups=self.groups)

        # Reshape the output back to the original shape (Width, Height, Channels)
        x = x.transpose(1, 3).reshape((x.shape[2], x.shape[3], self.n_out_channels))
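        # Note that x.shape above still refers to the convolution's output of shape
        # (1, C, H, W), since the right-hand side is evaluated before the assignment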

        # Subtract a threshold from the result, if one was given
        if self.threshold is not None:
            x -= self.threshold

        return x


class Filter:
    """Filter class used in the app."""

    def __init__(self, filter_name):
        """Initializing the filter class using a given filter.

        Most filters can be found at https://en.wikipedia.org/wiki/Kernel_(image_processing).

        Args:
            filter_name (str): The filter to consider.
        """

        assert filter_name in AVAILABLE_FILTERS, (
            f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, "
            f"but got {filter_name}"
        )

        # Define the attributes associated with the filter
        self.filter_name = filter_name
        self.onnx_model = None
        self.fhe_circuit = None
        self.divide = None
        self.repeat_out_channels = False

        # Instantiate the torch module associated with the given filter name
        if filter_name == "identity":
            self.torch_model = TorchIdentity()

        elif filter_name == "inverted":
            self.torch_model = TorchInverted()

        elif filter_name == "rotate":
            self.torch_model = TorchRotate()

        elif filter_name == "black and white":
            # Define the grayscale weights (RGB order)
            # These weights were used in PAL and NTSC video systems and can be found at
            # https://en.wikipedia.org/wiki/Grayscale
            # These weights are initially floats (0.299, 0.587, 0.114), with
            # 0.299 + 0.587 + 0.114 = 1
            # However, since FHE computations require integer weights, they are first
            # multiplied by a factor of 1000. The output image's values are then divided by
            # 1000 in post-processing in order to retrieve the correct result
            kernel = [299, 587, 114]
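            # As a quick sanity check on a hypothetical pixel (100, 150, 50):
            # 299 * 100 + 587 * 150 + 114 * 50 = 123650, and 123650 // 1000 = 123, matching
            # the float computation 0.299 * 100 + 0.587 * 150 + 0.114 * 50 = 123.65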

            self.torch_model = TorchConv(kernel, n_out_channels=1, groups=1)

            # Define the value used for dividing the output values in post-processing
            self.divide = 1000

            # Indicate that the output channels will need to be repeated, as Gradio requires
            # all images to be in RGB format, even grayscale ones
            self.repeat_out_channels = True

        elif filter_name == "blur":
            kernel = np.ones((3, 3))

            self.torch_model = TorchConv(kernel, n_out_channels=3, groups=3)

            # Define the value used for dividing the output values in post-processing. Since
            # the kernel sums each 3x3 neighborhood, dividing by 9 turns that sum into an
            # average, producing the blur effect
            self.divide = 9

        elif filter_name == "sharpen":
            kernel = [
                [0, -1, 0],
                [-1, 5, -1],
                [0, -1, 0],
            ]

            self.torch_model = TorchConv(kernel, n_out_channels=3, groups=3)

        elif filter_name == "ridge detection":
            kernel = [
                [-1, -1, -1],
                [-1, 9, -1],
                [-1, -1, -1],
            ]

            # In addition to the convolution operator, the filter subtracts a given threshold
            # value from the result in order to better display the ridges
            self.torch_model = TorchConv(kernel, n_out_channels=1, groups=1, threshold=900)
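
            # Since the kernel's weights sum to 1 and are applied to all 3 input channels,
            # a uniform region of value v outputs 3*v <= 765. Subtracting 900 therefore sends
            # flat areas below 0, which post-processing clips to black, leaving the ridges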

            # Indicate that the output channels will need to be repeated, as Gradio requires
            # all images to be in RGB format, even grayscale ones. Ridge detection images are
            # usually displayed as such
            self.repeat_out_channels = True

    def compile(self):
        """Compile the filter on a representative inputset."""
        # Generate a random representative set of images used for compilation. Concrete
        # Numpy's compiler expects the inputset as a tuple of individual samples, each
        # following the (Width, Height, Channels) shape convention used by the filters,
        # i.e. a Tuple[np.ndarray[shape=INPUT_SHAPE + (3,)]]
        np.random.seed(42)
        inputset = tuple(
            np.random.randint(0, 256, size=INPUT_SHAPE + (3,), dtype=np.int64) for _ in range(100)
        )

        # Convert the Torch module to a Numpy module
        numpy_module = NumpyModule(
            self.torch_model,
            dummy_input=torch.from_numpy(inputset[0]),
        )

        # Get the proxy function and parameter mapping used for initializing the compiler
        # This makes it possible to pass modules with an arbitrary number of encrypted
        # arguments to Concrete Numpy's compiler
        numpy_filter_proxy, parameters_mapping = generate_proxy_function(
            numpy_module.numpy_forward,
            ["inputs"],
        )

        # Compile the filter and retrieve its FHE circuit
        compiler = Compiler(
            numpy_filter_proxy,
            {parameters_mapping["inputs"]: "encrypted"},
        )
        self.fhe_circuit = compiler.compile(inputset)

        return self.fhe_circuit

    def post_processing(self, output_image):
        """Apply post-processing to the decrypted output images.

        Args:
            output_image (np.ndarray): The decrypted image to post-process.

        Returns:
            output_image (np.ndarray): The post-processed image.
        """
        # Divide all values if needed
        if self.divide is not None:
            output_image //= self.divide

        # Clip the image's values to proper RGB standards as filters don't handle such constraints
        output_image = output_image.clip(0, 255)

        # Gradio requires all images to follow an RGB format
        if self.repeat_out_channels:
            output_image = output_image.repeat(3, axis=2)

        return output_image
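

# A minimal usage sketch, not part of the app itself: compile the "inverted" filter and run
# it homomorphically on a random image. This assumes Concrete Numpy's Circuit exposes
# encrypt_run_decrypt; the actual app splits encryption, execution and decryption between a
# client and a server
if __name__ == "__main__":
    image_filter = Filter("inverted")
    image_filter.compile()

    # Generate a random image following the same shape convention as the inputset
    input_image = np.random.randint(0, 256, size=INPUT_SHAPE + (3,), dtype=np.int64)

    # Run the full encrypt-run-decrypt pipeline and post-process the result
    output_image = image_filter.fhe_circuit.encrypt_run_decrypt(input_image)
    output_image = image_filter.post_processing(output_image)

    print(f"Filtered image shape: {output_image.shape}")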