import numpy as np
import torch

from ..metrics import ap_per_class


def fitness(x):
    # Model fitness as a weighted combination of metrics
    w = [0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.1, 0.9, 0.1, 0.9]
    return (x[:, :len(w)] * w).sum(1)


def ap_per_class_box_and_mask(
        tp_m,
        tp_b,
        conf,
        pred_cls,
        target_cls,
        plot=False,
        save_dir=".",
        names=(),
):
    """
    Args:
        tp_b: tp of boxes.
        tp_m: tp of masks.
        other arguments see `func: ap_per_class`.
    """
    results_boxes = ap_per_class(tp_b,
                                 conf,
                                 pred_cls,
                                 target_cls,
                                 plot=plot,
                                 save_dir=save_dir,
                                 names=names,
                                 prefix="Box")[2:]
    results_masks = ap_per_class(tp_m,
                                 conf,
                                 pred_cls,
                                 target_cls,
                                 plot=plot,
                                 save_dir=save_dir,
                                 names=names,
                                 prefix="Mask")[2:]

    results = {
        "boxes": {
            "p": results_boxes[0],
            "r": results_boxes[1],
            "ap": results_boxes[3],
            "f1": results_boxes[2],
            "ap_class": results_boxes[4]},
        "masks": {
            "p": results_masks[0],
            "r": results_masks[1],
            "ap": results_masks[3],
            "f1": results_masks[2],
            "ap_class": results_masks[4]}}
    return results


class Metric:

    def __init__(self) -> None:
        self.p = []  # (nc, )
        self.r = []  # (nc, )
        self.f1 = []  # (nc, )
        self.all_ap = []  # (nc, 10)
        self.ap_class_index = []  # (nc, )

    @property
    def ap50(self):
        """AP@0.5 of all classes.
        Return:
            (nc, ) or [].
        """
        return self.all_ap[:, 0] if len(self.all_ap) else []

    @property
    def ap(self):
        """AP@0.5:0.95
        Return:
            (nc, ) or [].
        """
        return self.all_ap.mean(1) if len(self.all_ap) else []

    @property
    def mp(self):
        """mean precision of all classes.
        Return:
            float.
        """
        return self.p.mean() if len(self.p) else 0.0

    @property
    def mr(self):
        """mean recall of all classes.
        Return:
            float.
        """
        return self.r.mean() if len(self.r) else 0.0

    @property
    def map50(self):
        """Mean AP@0.5 of all classes.
        Return:
            float.
        """
        return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0

    @property
    def map(self):
        """Mean AP@0.5:0.95 of all classes.
        Return:
            float.
        """
        return self.all_ap.mean() if len(self.all_ap) else 0.0

    def mean_results(self):
        """Mean of results, return mp, mr, map50, map"""
        return (self.mp, self.mr, self.map50, self.map)

    def class_result(self, i):
        """class-aware result, return p[i], r[i], ap50[i], ap[i]"""
        return (self.p[i], self.r[i], self.ap50[i], self.ap[i])

    def get_maps(self, nc):
        maps = np.zeros(nc) + self.map
        for i, c in enumerate(self.ap_class_index):
            maps[c] = self.ap[i]
        return maps

    def update(self, results):
        """
        Args:
            results: tuple(p, r, ap, f1, ap_class)
        """
        p, r, all_ap, f1, ap_class_index = results
        self.p = p
        self.r = r
        self.all_ap = all_ap
        self.f1 = f1
        self.ap_class_index = ap_class_index


class Metrics:
    """Metric for boxes and masks."""

    def __init__(self) -> None:
        self.metric_box = Metric()
        self.metric_mask = Metric()

    def update(self, results):
        """
        Args:
            results: Dict{'boxes': Dict{}, 'masks': Dict{}}
        """
        self.metric_box.update(list(results["boxes"].values()))
        self.metric_mask.update(list(results["masks"].values()))

    def mean_results(self):
        return self.metric_box.mean_results() + self.metric_mask.mean_results()

    def class_result(self, i):
        return self.metric_box.class_result(i) + self.metric_mask.class_result(i)

    def get_maps(self, nc):
        return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc)

    @property
    def ap_class_index(self):
        # boxes and masks have the same ap_class_index
        return self.metric_box.ap_class_index


class Semantic_Metrics:
    def __init__(self, nc, device):
        self.nc = nc  # number of classes
        self.device = device
        self.iou = []
        self.c_bit_counts = torch.zeros(nc, dtype = torch.long).to(device)
        self.c_intersection_counts = torch.zeros(nc, dtype = torch.long).to(device)
        self.c_union_counts = torch.zeros(nc, dtype = torch.long).to(device)

    def update(self, pred_masks, target_masks):
        nb, nc, h, w = pred_masks.shape
        device = pred_masks.device

        for b in range(nb):
            onehot_mask = pred_masks[b].to(device)
            # convert predict mask to one hot
            semantic_mask = torch.flatten(onehot_mask, start_dim = 1).permute(1, 0) # class x h x w -> (h x w) x class
            max_idx = semantic_mask.argmax(1)
            output_masks = (torch.zeros(semantic_mask.shape).to(self.device)).scatter(1, max_idx.unsqueeze(1), 1.0) # one hot: (h x w) x class
            output_masks = torch.reshape(output_masks.permute(1, 0), (nc, h, w)) # (h x w) x class -> class x h x w
            onehot_mask = output_masks.int()

            for c in range(self.nc):
                pred_mask = onehot_mask[c].to(device)
                target_mask = target_masks[b, c].to(device)

                # calculate IoU
                intersection = (torch.logical_and(pred_mask, target_mask).sum()).item()
                union = (torch.logical_or(pred_mask, target_mask).sum()).item()
                iou = 0. if (0 == union) else (intersection / union)

                # record class pixel counts, intersection counts, union counts
                self.c_bit_counts[c] += target_mask.int().sum()
                self.c_intersection_counts[c] += intersection
                self.c_union_counts[c] += union

                self.iou.append(iou)

    def results(self):
        # Mean IoU
        miou = 0. if (0 == len(self.iou)) else np.sum(self.iou) / (len(self.iou) * self.nc)

        # Frequency Weighted IoU
        c_iou = self.c_intersection_counts / (self.c_union_counts + 1)  # add smooth
        # c_bit_counts = self.c_bit_counts.astype(int)
        total_c_bit_counts = self.c_bit_counts.sum()
        freq_ious = torch.zeros(1, dtype = torch.long).to(self.device) if (0 == total_c_bit_counts) else (self.c_bit_counts / total_c_bit_counts) * c_iou
        fwiou = (freq_ious.sum()).item()

        return (miou, fwiou)

    def reset(self):
        self.iou = []
        self.c_bit_counts = torch.zeros(self.nc, dtype = torch.long).to(self.device)
        self.c_intersection_counts = torch.zeros(self.nc, dtype = torch.long).to(self.device)
        self.c_union_counts = torch.zeros(self.nc, dtype = torch.long).to(self.device)


KEYS = [
    "train/box_loss",
    "train/seg_loss",  # train loss
    "train/cls_loss",
    "train/dfl_loss",
    "train/fcl_loss",
    "train/dic_loss",
    "metrics/precision(B)",
    "metrics/recall(B)",
    "metrics/mAP_0.5(B)",
    "metrics/mAP_0.5:0.95(B)",  # metrics
    "metrics/precision(M)",
    "metrics/recall(M)",
    "metrics/mAP_0.5(M)",
    "metrics/mAP_0.5:0.95(M)",  # metrics
    "metrics/MIOUS(S)",
    "metrics/FWIOUS(S)",        # metrics
    "val/box_loss",
    "val/seg_loss",  # val loss
    "val/cls_loss",
    "val/dfl_loss",
    "val/fcl_loss",
    "val/dic_loss",
    "x/lr0",
    "x/lr1",
    "x/lr2",]

BEST_KEYS = [
    "best/epoch",
    "best/precision(B)",
    "best/recall(B)",
    "best/mAP_0.5(B)",
    "best/mAP_0.5:0.95(B)",
    "best/precision(M)",
    "best/recall(M)",
    "best/mAP_0.5(M)",
    "best/mAP_0.5:0.95(M)",
    "best/MIOUS(S)",
    "best/FWIOUS(S)",]