# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the sav_dataset directory of this source tree.
import json
import os
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_util


def decode_video(video_path: str) -> List[np.ndarray]:
    """
    Decode the video at `video_path` and return its frames as a list of
    (H, W, 3) RGB uint8 arrays.
    """
    video = cv2.VideoCapture(video_path)
    video_frames = []
    while video.isOpened():
        ret, frame = video.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            video_frames.append(frame)
        else:
            break
    video.release()
    return video_frames
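
# Example usage of decode_video (a minimal sketch; "my_video.mp4" is a
# hypothetical path, not a file shipped with this repository):
#
#     frames = decode_video("my_video.mp4")
#     # `frames` is a list of (H, W, 3) RGB uint8 arrays, one per decoded frame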


def show_anns(
    masks: List[np.ndarray], colors: List[np.ndarray], borders: bool = True
) -> None:
    """
    Overlay the binary masks on the current matplotlib axes, filling each mask
    with its corresponding color and optionally drawing its contours.
    """
    # return if no masks
    if len(masks) == 0:
        return

    # sort masks by size
    sorted_annot_and_color = sorted(
        zip(masks, colors), key=(lambda x: x[0].sum()), reverse=True
    )
    H, W = sorted_annot_and_color[0][0].shape[:2]

    canvas = np.ones((H, W, 4))
    canvas[:, :, 3] = 0  # start fully transparent (alpha channel = 0)
    contour_thickness = max(1, int(min(5, 0.01 * min(H, W))))
    for mask, color in sorted_annot_and_color:
        canvas[mask] = np.concatenate([color, [0.55]])
        if borders:
            contours, _ = cv2.findContours(
                np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
            )
            cv2.drawContours(
                canvas, contours, -1, (0.05, 0.05, 0.05, 1), thickness=contour_thickness
            )

    ax = plt.gca()
    ax.imshow(canvas)
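
# Example usage of show_anns (a minimal sketch; `rles` stands in for a list of
# COCO-style RLE dicts, e.g. the masks of one annotated frame from a SAV
# annotation JSON, and `frame` is the matching RGB frame):
#
#     masks = [mask_util.decode(rle) > 0 for rle in rles]
#     colors = np.random.random((len(masks), 3))
#     plt.imshow(frame)
#     show_anns(masks, colors)
#     plt.axis("off")
#     plt.show()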


class SAVDataset:
    """
    SAVDataset is a class to load the SAV dataset and visualize the annotations.
    """

    def __init__(self, sav_dir: str, annot_sample_rate: int = 4):
        """
        Args:
            sav_dir: the directory of the SAV dataset
            annot_sample_rate: the frame sampling rate used to align the video
                with the annotations. The videos are 24 fps and the annotations
                are at 6 fps, so by default every 4th frame is annotated.
        """
        self.sav_dir = sav_dir
        self.annot_sample_rate = annot_sample_rate
        self.manual_mask_colors = np.random.random((256, 3))
        self.auto_mask_colors = np.random.random((256, 3))

    def read_frames(self, mp4_path: str) -> Optional[List[np.ndarray]]:
        """
        Read the video frames and downsample them to align with the annotations.
        """
        if not os.path.exists(mp4_path):
            print(f"{mp4_path} doesn't exist.")
            return None
        else:
            # decode the video
            frames = decode_video(mp4_path)
            print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).")

            # downsample the frames to align with the annotations
            frames = frames[:: self.annot_sample_rate]
            print(
                f"Videos are annotated every {self.annot_sample_rate} frames. "
                "To align with the annotations, "
                f"downsample the video to {len(frames)} frames."
            )
            return frames

    def get_frames_and_annotations(
        self, video_id: str
    ) -> Tuple[Optional[List[np.ndarray]], Optional[Dict], Optional[Dict]]:
        """
        Get the frames and the manual/auto annotations for the given video.
        """
        # load the video
        mp4_path = os.path.join(self.sav_dir, video_id + ".mp4")
        frames = self.read_frames(mp4_path)
        if frames is None:
            return None, None, None

        # load the manual annotations
        manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json")
        if not os.path.exists(manual_annot_path):
            print(f"{manual_annot_path} doesn't exist. Something might be wrong.")
            manual_annot = None
        else:
            with open(manual_annot_path) as f:
                manual_annot = json.load(f)

        # load the auto annotations
        auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json")
        if not os.path.exists(auto_annot_path):
            print(f"{auto_annot_path} doesn't exist.")
            auto_annot = None
        else:
            with open(auto_annot_path) as f:
                auto_annot = json.load(f)

        return frames, manual_annot, auto_annot

    def visualize_annotation(
        self,
        frames: List[np.ndarray],
        auto_annot: Optional[Dict],
        manual_annot: Optional[Dict],
        annotated_frame_id: int,
        show_auto: bool = True,
        show_manual: bool = True,
    ) -> None:
        """
        Visualize the annotations on the annotated_frame_id.
        If show_manual is True, show the manual annotations.
        If show_auto is True, show the auto annotations.
        By default, show both auto and manual annotations.
        """

        if annotated_frame_id >= len(frames):
            print(
                f"Invalid annotated_frame_id {annotated_frame_id}: "
                f"only {len(frames)} annotated frames are available."
            )
            return

        rles = []
        colors = []
        if show_manual and manual_annot is not None:
            rles.extend(manual_annot["masklet"][annotated_frame_id])
            colors.extend(
                self.manual_mask_colors[
                    : len(manual_annot["masklet"][annotated_frame_id])
                ]
            )
        if show_auto and auto_annot is not None:
            rles.extend(auto_annot["masklet"][annotated_frame_id])
            colors.extend(
                self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])]
            )

        plt.imshow(frames[annotated_frame_id])

        if len(rles) > 0:
            masks = [mask_util.decode(rle) > 0 for rle in rles]
            show_anns(masks, colors)
        else:
            print("No annotation will be shown")

        plt.axis("off")
        plt.show()
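

if __name__ == "__main__":
    # Minimal end-to-end sketch. The directory and video id below are
    # placeholders (not guaranteed to exist); point them at your local copy
    # of the SAV dataset before running.
    example_sav_dir = "path/to/sav_dir"  # hypothetical path
    example_video_id = "sav_000001"  # hypothetical video id
    dataset = SAVDataset(example_sav_dir)
    frames, manual_annot, auto_annot = dataset.get_frames_and_annotations(
        example_video_id
    )
    if frames is not None:
        # visualize both manual and auto masks on the first annotated frame
        dataset.visualize_annotation(
            frames, auto_annot, manual_annot, annotated_frame_id=0
        )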