from PIL import Image
from torch import Tensor
from typing import List, Optional
import numpy as np
import torchvision
import json


def load_json(path: str):
    """
    Load a json file from path and return the parsed data
    :param path: Path to the json file
    :return:
    data: Data in the json file
    """
    # Decode explicitly as UTF-8 so the result does not depend on the platform's
    # default locale encoding (which is not UTF-8 everywhere).
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_json(data: dict, path: str):
    """
    Serialize data as json and write it to a file
    :param data: Data to be saved
    :param path: Path of the file to write; an existing file is overwritten
    :return:
    """
    with open(path, mode="w") as out_file:
        json.dump(data, fp=out_file)


def pil_loader(path):
    """
    Read an image file with PIL and convert it to RGB
    :param path: Path to the image
    :return:
    img: PIL Image in RGB mode
    """
    # The conversion happens while the file is still open, so the pixel data is
    # read before the handle is closed.
    with open(path, mode='rb') as stream:
        rgb_image = Image.open(stream).convert('RGB')
    return rgb_image


def get_dimensions(image: Tensor):
    """
    Get the spatial dimensions (height, width) of the image
    :param image: Tensor laid out as [..., H, W], np.ndarray laid out as [H, W] or [H, W, C], or PIL Image
    :return:
    h: Height of the image
    w: Width of the image
    :raises ValueError: If the image type is not supported or the tensor/array rank is not supported
    """
    if isinstance(image, Tensor):
        # Height and width are the two trailing dimensions, so plain [H, W],
        # [C, H, W] and batched [..., C, H, W] tensors are all accepted.
        if image.ndim < 2:
            raise ValueError(f"Invalid image shape: {tuple(image.shape)}")
        h, w = image.shape[-2:]
    elif isinstance(image, np.ndarray):
        # Height and width are the two leading dimensions; the channel axis is optional.
        if image.ndim not in (2, 3):
            raise ValueError(f"Invalid image shape: {image.shape}")
        h, w = image.shape[:2]
    elif isinstance(image, Image.Image):
        # PIL reports size as (width, height)
        w, h = image.size
    else:
        raise ValueError(f"Invalid image type: {type(image)}")
    return h, w


def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None,
                          boxes: Optional[Tensor] = None, num_keypoints: int = 15):
    """
    Center crop the image and move the landmarks and the bounding box into the cropped frame.
    ``parts`` and ``boxes`` are updated in place and also returned.
    :param img: Image
    :param output_size: Output size of the cropped image: an int (square crop), a sequence with one
    element (square crop) or a sequence (height, width)
    :param parts: Locations of the landmarks of following format: <part_id> <x> <y> <visible>
    :param boxes: Bounding boxes of the landmarks of following format: <image_id> <x> <y> <width> <height>
    :param num_keypoints: Number of keypoints; at most this many rows of ``parts`` are processed
    :return:
    cropped_img: Center cropped image
    parts: Updated locations of the landmarks; landmarks that fall outside the crop are marked
    invisible and their coordinates are clamped to the crop
    boxes: Updated bounding box, clipped to the crop
    :raises ValueError: If output_size is neither an int nor a sequence of one or two elements
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])
    elif not (isinstance(output_size, (tuple, list)) and len(output_size) == 2):
        raise ValueError(f"Invalid output size: {output_size}")

    crop_height, crop_width = output_size
    # Measure the image before cropping: the offsets are relative to the original frame
    image_height, image_width = get_dimensions(img)
    img = torchvision.transforms.functional.center_crop(img, output_size)

    crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size)

    if parts is not None:
        # Never index past the rows that are actually present
        for j in range(min(num_keypoints, len(parts))):
            part = parts[j]
            # Skip if part is invisible
            if part[-1] == 0:
                continue
            part[1] -= crop_left
            part[2] -= crop_top

            # Mark the part as invisible if it ended up outside the crop
            if not (0 <= part[1] <= crop_width and 0 <= part[2] <= crop_height):
                part[-1] = 0

            # Keep the stored coordinates inside the crop
            part[1] = max(0, min(crop_width, part[1]))
            part[2] = max(0, min(crop_height, part[2]))

    if boxes is not None:
        # Shift the top-left corner into the cropped frame. The subtraction creates
        # fresh values, so the later writes into ``boxes`` cannot alias them.
        left = boxes[1] - crop_left
        top = boxes[2] - crop_top

        # Width/height are only clipped when the box actually carries them
        has_extent = len(boxes) > 4
        if has_extent:
            # The far corner must be derived from the unclamped origin, otherwise a box
            # that starts outside the crop would keep its full width/height.
            right = min(crop_width, max(0, left + boxes[3]))
            bottom = min(crop_height, max(0, top + boxes[4]))

        left = min(crop_width, max(0, left))
        top = min(crop_height, max(0, top))

        boxes[1] = left
        boxes[2] = top
        if has_extent:
            boxes[3] = right - left
            boxes[4] = bottom - top

    return img, parts, boxes


def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448):
    """
    Get the offset of the center crop window relative to the original image.
    Each axis is handled independently: when the crop is larger than the image along an axis the
    image is padded on that axis, so the crop window starts before the image and the offset is
    negative (minus the leading padding). Subtracting the returned offsets from a coordinate in the
    original image therefore always yields the coordinate in the cropped image.
    :param image_height: Height of the image
    :param image_width: Width of the image
    :param output_size: Output size of the cropped image: an int (square crop), a sequence with one
    element (square crop) or a sequence (height, width)
    :return:
    crop_top: Top coordinate of the cropped image
    crop_left: Left coordinate of the cropped image
    :raises ValueError: If output_size is neither an int nor a sequence of one or two elements
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])
    elif not (isinstance(output_size, (tuple, list)) and len(output_size) == 2):
        raise ValueError(f"Invalid output size: {output_size}")

    crop_height, crop_width = output_size

    def _axis_offset(image_extent, crop_extent):
        # Crop larger than the image: the leading padding shifts the content forward,
        # which is a negative window offset.
        if crop_extent > image_extent:
            return -((crop_extent - image_extent) // 2)
        # Regular case: window centered inside the image (0 when the sizes match)
        return int(round((image_extent - crop_extent) / 2.0))

    crop_top = _axis_offset(image_height, crop_height)
    crop_left = _axis_offset(image_width, crop_width)

    return crop_top, crop_left