import argparse

import cv2
import numpy as np
import torch

import kornia as K
from kornia.contrib import FaceDetector, FaceDetectorResult

import gradio as gr

import face_detection


def compare_detect_faces(
    img: np.ndarray,
    confidence_threshold,
    nms_threshold,
    kornia_toggle,
    retina_toggle,
    retina_mobile_toggle,
    dsfd_toggle,
):
    """Run every enabled face detector on ``img`` and return the annotated images.

    Args:
        img: Input image handed over by the Gradio image component.
        confidence_threshold: Confidence threshold forwarded to each detector.
        nms_threshold: Non-maximum-suppression threshold forwarded to each
            detector.
        kornia_toggle: ``"On"`` to run ``kornia_detect``; any other value
            skips it.
        retina_toggle: ``"On"`` to run ``retina_detect``.
        retina_mobile_toggle: ``"On"`` to run ``retina_mobilenet_detect``.
        dsfd_toggle: ``"On"`` to run ``dsfd_detect``.

    Returns:
        A 4-tuple ``(kornia, retina, retina_mobile, dsfd)``. Each entry is the
        image returned by the corresponding detector function, or ``None``
        when that detector was not switched on.
    """
    # Every detector takes the same two thresholds; build the kwargs once.
    thresholds = {
        "confidence_threshold": confidence_threshold,
        "nms_threshold": nms_threshold,
    }

    # Conditional expressions keep the calls lazy: a detector that is switched
    # off is never invoked. Evaluation order matches the returned tuple.
    kornia_detections = (
        kornia_detect(img, **thresholds) if kornia_toggle == "On" else None
    )
    retina_detections = (
        retina_detect(img, **thresholds) if retina_toggle == "On" else None
    )
    retina_mobile_detections = (
        retina_mobilenet_detect(img, **thresholds)
        if retina_mobile_toggle == "On"
        else None
    )
    dsfd_detections = (
        dsfd_detect(img, **thresholds) if dsfd_toggle == "On" else None
    )

    return (
        kornia_detections,
        retina_detections,
        retina_mobile_detections,
        dsfd_detections,
    )

def scale_image(img: np.ndarray, size: int) -> np.ndarray:
    """Resize ``img`` so that it is ``size`` pixels wide, keeping the aspect ratio.

    Args:
        img: Image array whose first two dimensions are (height, width).
        size: Target width in pixels.

    Returns:
        The image resized with ``cv2.resize`` (default interpolation).
    """
    h, w = img.shape[:2]
    # Use the requested width directly: recomputing it as int(w * (size / w))
    # round-trips through floating point and may truncate to ``size - 1``.
    new_w = int(size)
    # Clamp to one row so an extremely wide image does not produce a
    # zero-height target size, which cv2.resize rejects.
    new_h = max(1, int(h * size / w))
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(img, (new_w, new_h))


def base_detect(detector, img):
    """Detect faces with ``detector`` and draw them on a resized copy of ``img``.

    The image is first scaled to a width of 640 via ``scale_image``; the
    detector runs on that scaled image, and one green, 1-pixel rectangle is
    drawn per detection. The first four values of each detection are used as
    the two rectangle corners.

    Args:
        detector: Object exposing ``detect(image)`` that yields one array per
            detected face.
        img: Input image array.

    Returns:
        The scaled image copy with the rectangles drawn on it.
    """
    resized = scale_image(img, 640)
    boxes = detector.detect(resized)

    # Draw on a copy so the array handed to the detector is left untouched.
    annotated = resized.copy()
    for det in boxes:
        corner_a = det[:2].astype(int).tolist()
        corner_b = det[2:4].astype(int).tolist()
        annotated = cv2.rectangle(annotated, corner_a, corner_b, (0, 255, 0), 1)

    return annotated


def retina_detect(img, confidence_threshold, nms_threshold):
    """Annotate ``img`` with faces found by the ``"RetinaNetResNet50"`` detector.

    Args:
        img: Input image array.
        confidence_threshold: Passed to ``build_detector`` as
            ``confidence_threshold``.
        nms_threshold: Passed to ``build_detector`` as ``nms_iou_threshold``.

    Returns:
        The annotated image produced by ``base_detect``.
    """
    # A fresh detector is built on every call so the current slider values
    # are always the ones in effect.
    resnet50_detector = face_detection.build_detector(
        "RetinaNetResNet50",
        confidence_threshold=confidence_threshold,
        nms_iou_threshold=nms_threshold,
    )
    return base_detect(resnet50_detector, img)


def retina_mobilenet_detect(img, confidence_threshold, nms_threshold):
    """Annotate ``img`` with faces found by the ``"RetinaNetMobileNetV1"`` detector.

    Args:
        img: Input image array.
        confidence_threshold: Passed to ``build_detector`` as
            ``confidence_threshold``.
        nms_threshold: Passed to ``build_detector`` as ``nms_iou_threshold``.

    Returns:
        The annotated image produced by ``base_detect``.
    """
    # A fresh detector is built on every call so the current slider values
    # are always the ones in effect.
    mobilenet_detector = face_detection.build_detector(
        "RetinaNetMobileNetV1",
        confidence_threshold=confidence_threshold,
        nms_iou_threshold=nms_threshold,
    )
    return base_detect(mobilenet_detector, img)


def dsfd_detect(img, confidence_threshold, nms_threshold):
    """Annotate ``img`` with faces found by the ``"DSFDDetector"`` detector.

    Args:
        img: Input image array.
        confidence_threshold: Passed to ``build_detector`` as
            ``confidence_threshold``.
        nms_threshold: Passed to ``build_detector`` as ``nms_iou_threshold``.

    Returns:
        The annotated image produced by ``base_detect``.
    """
    # A fresh detector is built on every call so the current slider values
    # are always the ones in effect.
    dsfd_detector = face_detection.build_detector(
        "DSFDDetector",
        confidence_threshold=confidence_threshold,
        nms_iou_threshold=nms_threshold,
    )
    return base_detect(dsfd_detector, img)
    


def kornia_detect(img, confidence_threshold, nms_threshold):
    """Annotate ``img`` with faces found by Kornia's ``FaceDetector``.

    The image is scaled to a width of 400 via ``scale_image``, converted to a
    float tensor on the CPU, and passed through the detector. One green,
    1-pixel rectangle is drawn per detection on a copy of the scaled image.

    Args:
        img: Input image array.
        confidence_threshold: Forwarded to ``FaceDetector``.
        nms_threshold: Forwarded to ``FaceDetector``.

    Returns:
        The scaled image copy with the face rectangles drawn on it.
    """
    cpu = torch.device('cpu')

    # Scale first; the scaled array is also the canvas for the drawing below.
    resized = scale_image(img, 400)

    # Array -> batched tensor (keepdim=False), then channel swap.
    batch = K.image_to_tensor(resized, keepdim=False).to(cpu)
    # NOTE(review): this swap assumes a BGR input. Gradio image components
    # normally hand over RGB arrays, in which case the detector would receive
    # BGR data -- confirm the intended channel order.
    batch = K.color.bgr_to_rgb(batch.float())

    # Named ``detector`` so the imported ``face_detection`` module is not
    # shadowed inside this function.
    detector = FaceDetector(
        confidence_threshold=confidence_threshold,
        nms_threshold=nms_threshold,
    ).to(cpu)

    with torch.no_grad():
        raw_output = detector(batch)
    # Only the first (and only) batch entry is used; wrap each of its rows.
    faces = [FaceDetectorResult(row) for row in raw_output[0]]

    annotated = resized.copy()
    for face in faces:
        top_left = face.top_left.int().tolist()
        bottom_right = face.bottom_right.int().tolist()
        annotated = cv2.rectangle(annotated, top_left, bottom_right, (0, 255, 0), 1)

    return annotated
    
# Input image plus one output panel per detector.
input_image = gr.components.Image()
image_kornia = gr.components.Image(label="Kornia YuNet")
image_retina = gr.components.Image(label="RetinaFace")
image_retina_mobile = gr.components.Image(label="Retina Mobilenet")
image_dsfd = gr.components.Image(label="DSFD")

# Thresholds shared by all detectors.
confidence_slider = gr.components.Slider(
    minimum=0.1,
    maximum=0.95,
    value=0.5,
    step=0.05,
    label="Confidence Threshold",
)
nms_slider = gr.components.Slider(
    minimum=0.1,
    maximum=0.95,
    value=0.3,
    step=0.05,
    label="Non Maximum Suppression (NMS) Threshold",
)

# Per-detector switches; compare_detect_faces runs a detector only when its
# switch is set to "On".
kornia_radio = gr.Radio(["On", "Off"], value="On", label="Kornia YuNet")
retinanet_radio = gr.Radio(["On", "Off"], value="On", label="RetinaFace")
retina_mobile_radio = gr.Radio(["On", "Off"], value="On", label="Retina Mobilenets")
dsfd_radio = gr.Radio(["On", "Off"], value="On", label="DSFD")

# Markdown text rendered under the title of the Gradio interface.
description = """This space lets you compare different face detection algorithms, based on Convolutional Neural Networks (CNNs).

The models used here are:
* Kornia YuNet: High Speed. Using the [Kornia Face Detection](https://kornia.readthedocs.io/en/latest/applications/face_detection.html) implementation
* RetinaFace: High Accuracy. Using the [RetinaFace](https://arxiv.org/pdf/1905.00641.pdf) implementation with ResNet50 backbone from the [face-detection library](https://github.com/hukkelas/DSFD-Pytorch-Inference)
* RetinaMobileNet: Mid Speed, Mid Accuracy. RetinaFace with a MobileNetV1 backbone, also from the [face-detection library](https://github.com/hukkelas/DSFD-Pytorch-Inference)
* DSFD: High Accuracy. [Dual Shot Face Detector](http://openaccess.thecvf.com/content_CVPR_2019/papers/Li_DSFD_Dual_Shot_Face_Detector_CVPR_2019_paper.pdf) from the [face-detection library](https://github.com/hukkelas/DSFD-Pytorch-Inference) as well.
"""

# Build the comparison UI: one input image, two threshold sliders and four
# on/off switches in; four annotated images out.
compare_iface = gr.Interface(
    fn=compare_detect_faces,
    inputs=[
        input_image,
        confidence_slider,
        nms_slider,
        kornia_radio,
        retinanet_radio,
        retina_mobile_radio,
        dsfd_radio,
    ],
    outputs=[image_kornia, image_retina, image_retina_mobile, image_dsfd],
    # Every example uses the same defaults: confidence 0.5, NMS 0.3 and all
    # four detectors switched on.
    examples=[
        [image_path, 0.5, 0.3, "On", "On", "On", "On"]
        for image_path in (
            "data/50_Celebration_Or_Party_birthdayparty_50_25.jpg",
            "data/12_Group_Group_12_Group_Group_12_39.jpg",
            "data/31_Waiter_Waitress_Waiter_Waitress_31_55.jpg",
            "data/12_Group_Group_12_Group_Group_12_283.jpg",
        )
    ],
    title="Face Detections",
    description=description,
)

# Launch in a separate statement so ``compare_iface`` stays bound to the
# Interface object instead of to whatever ``launch()`` returns.
compare_iface.launch()