import numpy as np
class GraspCoder:
    """
    Encode/decode grasp rectangle annotations, analogous to a BoxCoder class.

    Supported operations:
        1. Encode grasp annotations:
            (x1, y1, x2, y2, x3, y3, x4, y4) -> (x_center, y_center, width, height, sine(theta))
        2. Decode grasp annotations:
            (x_center, y_center, width, height, sine(theta)) -> (x1, y1, x2, y2, x3, y3, x4, y4)
        3. Resize grasp annotations when the image is resized
        4. Flip grasp annotations for horizontal/vertical flip augmentations

    One GraspCoder instance holds the annotations of a single image.

    Corner convention: the four corners are expected in order around the
    rectangle, so that p1-p2 is one side ("width") and p2-p3 the adjacent side
    ("height"); p1 and p3 are then opposite corners.
    """

    def __init__(self, height, width, grasp_annos, grasp_annos_reformat=None):
        """
        Args:
            height: height of image
            width: width of image
            grasp_annos: list of numpy.arrays, each of length 8, in format of
                (x1, y1, x2, y2, x3, y3, x4, y4)
            grasp_annos_reformat: optional list of numpy.arrays, each of length 5, in
                format of (x_center, y_center, width, height, sine(theta)); this is
                what ``decode`` reads.
        """
        self.height = height
        self.width = width
        self.grasp_annos = grasp_annos
        self.grasp_annos_reformat = grasp_annos_reformat

    def __len__(self):
        return len(self.grasp_annos)

    @staticmethod
    def _float_copy(grasp):
        """Return a fresh floating-point copy of one corner annotation (input is never mutated)."""
        arr = np.array(grasp)  # np.array copies by default
        # Integer arrays would silently truncate scaled/flipped coordinates on assignment.
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float64)
        return arr

    def encode(self, normalize=True):
        """
        (x1, y1, x2, y2, x3, y3, x4, y4) -> (x_center, y_center, width, height, sine(theta))

        The result is also stored in ``self.grasp_annos_reformat``.

        Args:
            normalize -> bool: return values normalized to 0~1 or not. When True,
                x_center/width are divided by the image width, y_center/height by the
                image height, and sine is mapped from [-1, 1] to [0, 1].
        Returns:
            grasp_annos_reformat: List of numpy.array
        """
        grasp_annos_reformat = []
        for grasp in self.grasp_annos:
            x1, y1, x2, y2, x3, y3, x4, y4 = tuple(grasp)
            # Canonicalise corner order so side p1-p2 lies right of (or level with) the
            # opposite side p3-p4; this keeps cos(theta) >= 0, so sine alone fixes theta.
            if (x1 + x2) < (x3 + x4):
                x1, y1, x2, y2, x3, y3, x4, y4 = x3, y3, x4, y4, x1, y1, x2, y2
            x_center = (x1 + x3) / 2
            y_center = (y1 + y3) / 2
            width = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
            height = np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
            # The midpoint of side p1-p2 is offset from the centre by (p2 - p3) / 2, i.e. by
            # height / 2, so this ratio is the sine of that offset's direction. A
            # zero-height rectangle has no defined direction: use 0 instead of NaN.
            if height > 0:
                sine = ((y1 + y2) / 2 - y_center) / (height / 2)
            else:
                sine = 0.0
            if normalize:
                x_center /= self.width
                y_center /= self.height
                width /= self.width
                height /= self.height
                sine = (sine + 1) / 2
            grasp_annos_reformat.append(np.array([x_center, y_center, width, height, sine]))
        self.grasp_annos_reformat = grasp_annos_reformat
        return grasp_annos_reformat

    def decode(self, normalize=True):
        """
        (x_center, y_center, width, height, sine(theta)) -> (x1, y1, x2, y2, x3, y3, x4, y4)

        Decodes ``self.grasp_annos_reformat``, overwrites ``self.grasp_annos`` with the
        result and returns it.

        Args:
            normalize -> bool: whether ``grasp_annos_reformat`` holds normalized values;
                pass the same flag that was given to ``encode``. Defaults to True.
        Returns:
            grasp_annos: List of numpy.array
        """
        grasp_annos = []
        for grasp in self.grasp_annos_reformat:
            x_center, y_center, width, height, sine = tuple(grasp)
            if normalize:
                x_center *= self.width
                y_center *= self.height
                width *= self.width
                height *= self.height
                sine = sine * 2 - 1
            # Out-of-range sines (e.g. unclamped external values) would make the
            # square root below NaN; clamp them to the valid range instead.
            sine = float(np.clip(sine, -1.0, 1.0))
            # encode() guarantees cos(theta) >= 0, so the non-negative root is correct.
            cosine = np.sqrt(1 - sine ** 2)
            x1 = x_center + cosine * height / 2 + sine * width / 2
            x2 = x_center + cosine * height / 2 - sine * width / 2
            y1 = y_center + sine * height / 2 - cosine * width / 2
            y2 = y_center + sine * height / 2 + cosine * width / 2
            # p3 / p4 are the reflections of p1 / p2 through the centre.
            x3 = x_center * 2 - x1
            x4 = x_center * 2 - x2
            y3 = y_center * 2 - y1
            y4 = y_center * 2 - y2
            grasp_annos.append(np.array([x1, y1, x2, y2, x3, y3, x4, y4]))
        self.grasp_annos = grasp_annos
        return grasp_annos

    def resize(self, new_size):
        """
        Resize the grasp annotations according to the resized image.

        New float arrays are created; the arrays originally passed in are left
        untouched. ``grasp_annos_reformat`` is not updated, so call ``encode``
        again if it is needed for the resized image.

        Args:
            new_size -> Tuple: (new_width, new_height) of the resized image

        Returns:
            self
        """
        new_width, new_height = new_size
        old_height, old_width = self.height, self.width
        resized_grasp_annos = []
        for grasp in self.grasp_annos:
            resized = self._float_copy(grasp)
            resized[0::2] = resized[0::2] / old_width * new_width
            resized[1::2] = resized[1::2] / old_height * new_height
            resized_grasp_annos.append(resized)
        self.grasp_annos = resized_grasp_annos
        self.height, self.width = new_height, new_width

        return self

    def transpose(self, axis):
        """
        For Horizontal/Vertical flip.

        New float arrays are created; the arrays originally passed in are left
        untouched. ``grasp_annos_reformat`` is not updated, so call ``encode``
        again if it is needed for the flipped image.

        Args:
            axis: 0 flips the X coordinates (x -> width - x),
                1 flips the Y coordinates (y -> height - y)

        Returns:
            self

        Raises:
            ValueError: if axis is neither 0 nor 1 (previously this silently
                dropped every annotation).
        """
        if axis not in (0, 1):
            raise ValueError("axis must be 0 (X) or 1 (Y), got {!r}".format(axis))
        coords = slice(0, None, 2) if axis == 0 else slice(1, None, 2)
        extent = self.width if axis == 0 else self.height
        flipped_grasp_annos = []
        for grasp in self.grasp_annos:
            flipped = self._float_copy(grasp)
            flipped[coords] = extent - flipped[coords]
            flipped_grasp_annos.append(flipped)
        self.grasp_annos = flipped_grasp_annos
        return self