# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

from typing import Any, Dict, List, Tuple

from .sam.image_encoder import ImageEncoderViT
from .sam.mask_decoder import MaskDecoder
from .sam.prompt_encoder import PromptEncoder
from .detr import box_ops
from .detr.segmentation import dice_loss, sigmoid_focal_loss
from .detr.misc import nested_tensor_from_tensor_list, interpolate
from . import axis_ops, ilnr_loss #, pwnp_loss
from .vnl_loss import VNL_Loss
from .midas_loss import MidasLoss


class SamTransformer(nn.Module):
    """SAM-style model with object, affordance and depth decoding heads.

    For every image in the batch the keypoint prompts are encoded with
    ``prompt_encoder`` and decoded three times against the same image
    embedding:

    * ``mask_decoder`` returns object masks together with movable / rigid /
      kinematic / action / axis outputs,
    * ``affordance_decoder`` returns affordance maps,
    * ``depth_decoder`` is prompted with the learned ``depth_query``
      embeddings (not with the keypoints) and returns a depth map.

    ``forward`` returns the predictions and, unless ``backward=False``, the
    training losses in the same dictionary.
    """

    mask_threshold: float = 0.0
    image_format: str = "RGB"

    # Resolution, in (H, W), that object masks and depth maps are resized to
    # and that the virtual-normal loss is configured with.
    _image_size: Tuple[int, int] = (768, 1024)
    # Resolution, in (H, W), that affordance maps are resized to.
    _affordance_map_size: Tuple[int, int] = (192, 256)

    def __init__(
        self,
        image_encoder: "ImageEncoderViT",
        prompt_encoder: "PromptEncoder",
        mask_decoder: "MaskDecoder",
        affordance_decoder: "MaskDecoder",
        depth_decoder: "MaskDecoder",
        transformer_hidden_dim: int,
        backbone_name: str,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          affordance_decoder (MaskDecoder): Predicts affordance maps from the
            same image embeddings and encoded prompts.
          depth_decoder (MaskDecoder): Predicts the depth map from the image
            embeddings and the learned depth query embeddings.
          transformer_hidden_dim (int): Embedding size of the two learned
            depth query tokens.
          backbone_name (str): Currently unused. Loading of the pretrained SAM
            checkpoint that was selected by this name ('vit_h', 'vit_l',
            'vit_b') is disabled in this class.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.affordance_decoder = affordance_decoder

        # depth head
        self.depth_decoder = depth_decoder
        self.depth_query = nn.Embedding(2, transformer_hidden_dim)
        # Focal length (in pixels) from a fixed field of view of 1.0 (passed
        # to tan(), i.e. radians) and the image width; used for both fx and fy.
        fov = torch.tensor(1.0)
        focal_length = (self._image_size[1] / 2 / torch.tan(fov / 2)).item()
        self.vnl_loss = VNL_Loss(focal_length, focal_length, self._image_size)
        self.midas_loss = MidasLoss(alpha=0.1)

        # Non-persistent buffers: they follow the module across devices but
        # are not written to the state dict. The default lists are never
        # mutated; torch.Tensor(...) copies them.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), persistent=False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), persistent=False)

        self.num_queries = 15
        self._affordance_focal_alpha = 0.95
        self._ignore_index = -100

    @property
    def device(self) -> Any:
        """Device the module lives on, read from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def freeze_layers(self, names):
        """
        Freeze layers in 'names'.

        A parameter is frozen (``requires_grad = False``) when any entry of
        ``names`` occurs as a substring of its qualified parameter name.
        """
        for name, param in self.named_parameters():
            if any(freeze_name in name for freeze_name in names):
                param.requires_grad = False

    @staticmethod
    def _zero_loss(device) -> torch.Tensor:
        """Return a scalar zero loss on ``device`` for terms without valid targets."""
        return torch.tensor(0.0, requires_grad=True).to(device)

    def _classification_loss(self, logits: torch.Tensor, target: torch.Tensor, device) -> torch.Tensor:
        """Cross-entropy over the last dim of ``logits``; zero if the result is NaN.

        ``logits`` is permuted with (0, 2, 1) so the class dimension comes
        second, as ``F.cross_entropy`` expects. Targets equal to
        ``self._ignore_index`` are skipped. A NaN result (e.g. every target
        ignored) is replaced by a zero loss.
        """
        loss = F.cross_entropy(logits.permute(0, 2, 1), target, ignore_index=self._ignore_index)
        if torch.isnan(loss):
            return self._zero_loss(device)
        return loss

    def _pad_to_square(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad the last two dims on the bottom/right up to the encoder's ``img_size``."""
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        return F.pad(x, (0, padw, 0, padh))

    def forward(
        self,
        image: torch.Tensor,
        valid: torch.Tensor,
        keypoints: torch.Tensor,
        bbox: torch.Tensor,
        masks: torch.Tensor,
        movable: torch.Tensor,
        rigid: torch.Tensor,
        kinematic: torch.Tensor,
        action: torch.Tensor,
        affordance: torch.Tensor,
        affordance_map: torch.FloatTensor,
        depth: torch.Tensor,
        axis: torch.Tensor,
        fov: torch.Tensor,
        backward: bool = True,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Run all heads on a batch and, optionally, compute the training losses.

        Arguments:
          image: Batch of images; the last two dims are (H, W). The batch is
            zero-padded to the encoder's square input size here. `preprocess`
            is not called, so pixel normalization is presumably done by the
            caller -- TODO confirm.
          valid: Unused.
          keypoints: Point prompts per image. Every point of image `idx`
            becomes its own single-point prompt with label 1 (presumably
            shaped (B, num_queries, 2) -- TODO confirm).
          bbox: Ground-truth boxes in xyxy format; only their centers are
            used, as reference points for the axis losses.
          masks: Ground-truth object masks; a mask contributes to the mask
            losses only if its sum over the last two dims exceeds 10.
          movable, rigid, kinematic, action: Class-index targets for the
            corresponding heads; entries equal to -100 are ignored.
          affordance: Affordance targets; an entry is valid when its first
            component is > -0.5.
          affordance_map: Target affordance maps for the focal loss.
          depth: Ground-truth depth, (bs, H, W). A sample is used when its
            top-left pixel is > 0; pixels <= 1e-8 are masked out of the
            MiDaS loss.
          axis: Ground-truth axis lines in xyxy format; an entry is valid
            when its first component is > 0.
          fov: Unused.
          backward: When False, return the predictions without any loss.
          **kwargs: Ignored.

        Returns:
          Dict with 'pred_boxes', 'pred_movable', 'pred_rigid',
          'pred_kinematic', 'pred_action', 'pred_masks', 'pred_axis',
          'pred_depth', 'pred_affordance' and, when `backward` is True, the
          'loss_*' entries. 'loss_bbox' and 'loss_giou' are constant zeros.
        """
        device = image.device
        multimask_output = False
        input_size = image.shape[-2:]

        # image encoder (expects a square input, so pad first)
        image_embeddings = self.image_encoder(self._pad_to_square(image))

        # Prompt-independent decoder inputs, shared by every sample and head.
        image_pe = self.prompt_encoder.get_dense_pe()
        depth_sparse_embeddings = self.depth_query.weight.unsqueeze(0)

        # Per-sample predictions, keyed by their name in the returned dict.
        collected: Dict[str, List[torch.Tensor]] = {
            'pred_boxes': [],
            'pred_movable': [],
            'pred_rigid': [],
            'pred_kinematic': [],
            'pred_action': [],
            'pred_masks': [],
            'pred_axis': [],
            'pred_depth': [],
            'pred_affordance': [],
        }
        for idx, curr_embedding in enumerate(image_embeddings):
            embedding = curr_embedding.unsqueeze(0)

            # One single-point prompt per keypoint, all labelled 1.
            point_coords = keypoints[idx].unsqueeze(1)
            point_labels = torch.ones_like(point_coords[:, :, 0])
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=(point_coords, point_labels),
                boxes=None,
                masks=None,
            )

            # mask decoder (the IoU prediction is not used)
            (
                low_res_masks,
                _,
                output_movable,
                output_rigid,
                output_kinematic,
                output_action,
                output_axis,
            ) = self.mask_decoder(
                image_embeddings=embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            output_mask = self.postprocess_masks(
                low_res_masks,
                input_size=input_size,
                original_size=self._image_size,
            )
            collected['pred_masks'].append(output_mask[:, 0])
            collected['pred_movable'].append(output_movable[:, 0])
            collected['pred_rigid'].append(output_rigid[:, 0])
            collected['pred_kinematic'].append(output_kinematic[:, 0])
            collected['pred_action'].append(output_action[:, 0])
            collected['pred_axis'].append(output_axis[:, 0])

            # convert masks to boxes for evaluation
            pred_mask_bbox = (output_mask[:, 0] > 0.0).long()
            empty_mask = pred_mask_bbox.sum(dim=-1).sum(dim=-1)
            # An all-zero mask is filled with ones so that it still yields a
            # (full-image) box.
            pred_mask_bbox[empty_mask == 0] += 1
            pred_boxes = torchvision.ops.masks_to_boxes(pred_mask_bbox)
            # NOTE(review): the factors are passed as [1 / H, 1 / W]
            # (1 / 768, 1 / 1024); an earlier, disabled variant of this call
            # indexed the image size in the opposite order. Confirm which
            # order box_ops.rescale_bboxes expects.
            pred_boxes = box_ops.rescale_bboxes(
                pred_boxes,
                [1 / self._image_size[0], 1 / self._image_size[1]],
            )
            collected['pred_boxes'].append(pred_boxes)

            # affordance decoder
            low_res_masks, _ = self.affordance_decoder(
                image_embeddings=embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            output_aff_masks = self.postprocess_masks(
                low_res_masks,
                input_size=input_size,
                original_size=self._affordance_map_size,
            )
            collected['pred_affordance'].append(output_aff_masks[:, 0])

            # depth decoder: prompted by the learned depth queries and an
            # all-zero dense prompt instead of the keypoints.
            # NOTE(review): (256, 64, 64) presumably mirrors the prompt
            # encoder's dense embedding shape -- confirm.
            depth_dense_embeddings = torch.zeros((1, 256, 64, 64), device=dense_embeddings.device)
            low_res_masks, _ = self.depth_decoder(
                image_embeddings=embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=depth_sparse_embeddings,
                dense_prompt_embeddings=depth_dense_embeddings,
                multimask_output=multimask_output,
            )
            output_depth = self.postprocess_masks(
                low_res_masks,
                input_size=input_size,
                original_size=self._image_size,
            )
            collected['pred_depth'].append(output_depth[:, 0])

        out = {key: torch.stack(values) for key, values in collected.items()}

        if not backward:
            return out

        outputs_seg_masks = out['pred_masks']
        outputs_aff_masks = out['pred_affordance']
        outputs_axis = out['pred_axis']
        outputs_depth = out['pred_depth']

        # backward
        # Boxes are not regressed (they are derived from the masks), so the
        # box losses are constant zeros; the targets only serve as axis centers.
        target_boxes = box_ops.box_xyxy_to_cxcywh(bbox)
        out['loss_bbox'] = self._zero_loss(device)
        out['loss_giou'] = self._zero_loss(device)

        # affordance
        affordance_valid = affordance[:, :, 0] > -0.5
        num_affordance = affordance_valid.sum()
        if num_affordance == 0:
            out['loss_affordance'] = self._zero_loss(device)
        else:
            src_aff_masks = outputs_aff_masks[affordance_valid].flatten(1)
            tgt_aff_masks = affordance_map[affordance_valid].flatten(1)
            out['loss_affordance'] = sigmoid_focal_loss(
                src_aff_masks,
                tgt_aff_masks,
                num_affordance,
                alpha=self._affordance_focal_alpha,
            )

        # axis
        axis_valid = axis[:, :, 0] > 0.0
        num_axis = axis_valid.sum()
        if num_axis == 0:
            out['loss_axis_angle'] = self._zero_loss(device)
            out['loss_axis_offset'] = self._zero_loss(device)
            out['loss_eascore'] = self._zero_loss(device)
        else:
            # regress angle: the first two components are normalized to unit
            # length, the remaining ones are kept as predicted.
            src_axis_angle = outputs_axis[axis_valid]
            src_axis_angle_norm = F.normalize(src_axis_angle[:, :2])
            src_axis_angle = torch.cat((src_axis_angle_norm, src_axis_angle[:, 2:]), dim=-1)
            target_axis_xyxy = axis[axis_valid]

            # Use the box center (cx, cy) for both halves of the center tensor.
            axis_center = target_boxes[axis_valid].clone()
            axis_center[:, 2:] = axis_center[:, :2]
            target_axis_angle = axis_ops.line_xyxy_to_angle(target_axis_xyxy, center=axis_center)

            out['loss_axis_angle'] = (
                F.l1_loss(src_axis_angle[:, :2], target_axis_angle[:, :2], reduction='sum') / num_axis
            )
            out['loss_axis_offset'] = (
                F.l1_loss(src_axis_angle[:, 2:], target_axis_angle[:, 2:], reduction='sum') / num_axis
            )

            # Compare prediction and target as lines after mapping both back
            # through the same angle -> xyxy conversion.
            src_axis_xyxy = axis_ops.line_angle_to_xyxy(src_axis_angle, center=axis_center)
            target_axis_xyxy = axis_ops.line_angle_to_xyxy(target_axis_angle, center=axis_center)
            axis_eascore, _, _ = axis_ops.ea_score(src_axis_xyxy, target_axis_xyxy)
            out['loss_eascore'] = (1 - axis_eascore).mean()

        # physical-property classification heads
        out['loss_movable'] = self._classification_loss(out['pred_movable'], movable, device)
        out['loss_rigid'] = self._classification_loss(out['pred_rigid'], rigid, device)
        out['loss_kinematic'] = self._classification_loss(out['pred_kinematic'], kinematic, device)
        out['loss_action'] = self._classification_loss(out['pred_action'], action, device)

        # depth backward
        # (bs, 1, H, W)
        src_depths = interpolate(outputs_depth, size=depth.shape[-2:], mode='bilinear', align_corners=False)
        src_depths = src_depths.clamp(min=0.0, max=1.0)
        tgt_depths = depth.unsqueeze(1)  # depth is (bs, H, W)
        # A sample counts as having depth ground truth when its top-left
        # pixel is positive.
        valid_depth = depth[:, 0, 0] > 0
        if valid_depth.any():
            src_depths = src_depths[valid_depth]
            tgt_depths = tgt_depths[valid_depth]
            depth_mask = tgt_depths > 1e-8
            midas_loss, _, _ = self.midas_loss(src_depths, tgt_depths, depth_mask)
            out['loss_depth'] = midas_loss
            out['loss_vnl'] = self.vnl_loss(tgt_depths, src_depths)
        else:
            out['loss_depth'] = self._zero_loss(device)
            out['loss_vnl'] = self._zero_loss(device)

        # mask backward
        tgt_masks = masks
        src_masks = interpolate(outputs_seg_masks, size=tgt_masks.shape[-2:], mode='bilinear', align_corners=False)
        valid_mask = tgt_masks.sum(dim=-1).sum(dim=-1) > 10
        num_masks = valid_mask.sum()
        if num_masks == 0:
            out['loss_mask'] = self._zero_loss(device)
            out['loss_dice'] = self._zero_loss(device)
        else:
            src_masks = src_masks[valid_mask].flatten(1)
            tgt_masks = tgt_masks[valid_mask].flatten(1)
            tgt_masks = tgt_masks.view(src_masks.shape)
            out['loss_mask'] = sigmoid_focal_loss(src_masks, tgt_masks.float(), num_masks)
            out['loss_dice'] = dice_loss(src_masks, tgt_masks, num_masks)

        return out

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        # Upsample to the encoder's square input resolution ...
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        # ... drop the bottom/right padding ...
        masks = masks[..., : input_size[0], : input_size[1]]
        # ... and resize to the requested output resolution.
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        return self._pad_to_square(x)