import cv2
import numpy as np
import os
import argparse
from typing import Union
from matplotlib import pyplot as plt


class ScalingSquareDetector:
    """
    Locate a square calibration marker in an image and derive a scale factor.

    The marker is found by contour analysis in ``find_scaling_square``. The
    ORB/SIFT detector created in ``__init__`` is kept for API compatibility but
    is not used by the contour-based search.
    """

    # Annotated image written by find_scaling_square (debug mode only) and
    # later moved into the output folder by draw_debug_images.
    _DEBUG_IMAGE_NAME = "largest_contour.jpg"

    def __init__(self, feature_detector="ORB", debug=False):
        """
        Initialize the detector with the desired feature matching algorithm.
        :param feature_detector: "ORB" or "SIFT" (case-insensitive, default "ORB").
        :param debug: If True, find_scaling_square writes an annotated image
            to the current directory that draw_debug_images can collect.
        :raises ValueError: if feature_detector is neither "ORB" nor "SIFT".
        """
        self.feature_detector = feature_detector
        self.debug = debug
        self.detector = self._initialize_detector()

    def _initialize_detector(self):
        """
        Initialize the chosen feature detector.
        :return: OpenCV detector object (SIFT or ORB).
        :raises ValueError: if self.feature_detector is not "ORB" or "SIFT".
        """
        name = self.feature_detector.upper()
        if name == "SIFT":
            return cv2.SIFT_create()
        if name == "ORB":
            return cv2.ORB_create()
        raise ValueError("Invalid feature detector. Choose 'ORB' or 'SIFT'.")

    def find_scaling_square(
        self, reference_image_path, target_image, known_size_mm, roi_margin=30
    ):
        """
        Detect the scaling square in ``target_image`` and return its scale.

        The marker is taken to be the largest-area external contour whose
        bounding box has an aspect ratio in [0.9, 1.1] and whose polygonal
        approximation (epsilon = 2% of the perimeter) has exactly 4 vertices.

        :param reference_image_path: Unused; kept for backward compatibility.
        :param target_image: Image passed directly to cv2.findContours, which
            treats non-zero pixels as foreground. Presumably a single-channel
            binary image; in debug mode it is also converted with
            COLOR_GRAY2BGR, which requires a single channel.
        :param known_size_mm: Physical side length of the square in millimeters.
        :param roi_margin: Unused; kept for backward compatibility.
        :return: Scaling factor in millimeters per pixel.
        :raises ValueError: if no contours, or no square-like contour, is found.
        """
        contours, _ = cv2.findContours(
            target_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        if not contours:
            raise ValueError("No contours found in the cropped ROI.")

        # Select the largest square-like contour.
        largest_square = None
        largest_square_area = 0
        for contour in contours:
            _, _, w_c, h_c = cv2.boundingRect(contour)
            if not 0.9 <= w_c / float(h_c) <= 1.1:
                continue
            peri = cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, 0.02 * peri, True)
            if len(approx) != 4:
                continue
            area = cv2.contourArea(contour)
            if area > largest_square_area:
                largest_square = contour
                largest_square_area = area

        # Fail with a clear message instead of an opaque OpenCV error below.
        if largest_square is None:
            raise ValueError("No square-like contour found in the ROI.")

        if self.debug:
            target_image_color = cv2.cvtColor(target_image, cv2.COLOR_GRAY2BGR)
            # drawContours takes a *list* of contours; passing the bare array
            # makes every point its own contour, so only dots get drawn.
            cv2.drawContours(
                target_image_color, [largest_square], -1, (255, 0, 0), 3
            )
            cv2.imwrite(self._DEBUG_IMAGE_NAME, target_image_color)

        # Average the bounding-box sides to get the square's size in pixels.
        _, _, w, h = cv2.boundingRect(largest_square)
        avg_square_size_px = (w + h) / 2

        # Use the caller-supplied physical size (previously hard-coded to 0.5,
        # which yielded inches rather than mm per pixel for a 12.7 mm square).
        return known_size_mm / avg_square_size_px

    def draw_debug_images(self, output_folder):
        """
        Move debug images written by find_scaling_square into a folder.

        Does nothing unless debug mode is enabled.
        :param output_folder: Destination directory; created if missing.
        """
        if not self.debug:
            return
        os.makedirs(output_folder, exist_ok=True)
        for img_name in (self._DEBUG_IMAGE_NAME,):
            if os.path.exists(img_name):
                # os.replace overwrites an existing destination on all
                # platforms; os.rename raises on Windows in that case.
                os.replace(img_name, os.path.join(output_folder, img_name))


def calculate_scaling_factor(
    reference_image_path,
    target_image,
    known_square_size_mm=12.7,
    feature_detector="ORB",
    debug=False,
    roi_margin=30,
):
    """
    Build a ScalingSquareDetector and measure the square in ``target_image``.

    :param reference_image_path: Forwarded to find_scaling_square.
    :param target_image: Image searched for the scaling square.
    :param known_square_size_mm: Physical side length of the square in
        millimeters (default 12.7 mm, i.e. 0.5 inch).
    :param feature_detector: "ORB" or "SIFT", used to construct the detector.
    :param debug: When True, debug images are moved into "debug_outputs"
        after the measurement.
    :param roi_margin: Forwarded to find_scaling_square.
    :return: The scaling factor returned by find_scaling_square (intended to
        be millimeters per pixel).
    """
    square_detector = ScalingSquareDetector(
        feature_detector=feature_detector, debug=debug
    )
    factor = square_detector.find_scaling_square(
        reference_image_path,
        target_image,
        known_square_size_mm,
        roi_margin=roi_margin,
    )
    if debug:
        square_detector.draw_debug_images("debug_outputs")
    return factor


# Example usage:
if __name__ == "__main__":
    import traceback

    from PIL import Image
    from ultralytics import YOLO
    from ultralytics.utils.plotting import save_one_box

    from app import yolo_detect

    sample_dir = "./sample_images"
    output_dir = "./outputs"
    # cv2.imwrite does not create missing directories; ensure the target exists.
    os.makedirs(output_dir, exist_ok=True)

    # Load the weights once instead of on every loop iteration.
    model = YOLO("./last.pt")

    for idx, filename in enumerate(os.listdir(sample_dir)):
        out_path = f"{output_dir}/{idx}_{filename}"
        # Keep the whole per-file pipeline inside the try so a single bad
        # image (unreadable file, no detection, ...) does not abort the batch.
        try:
            with Image.open(os.path.join(sample_dir, filename)) as pil_img:
                img = np.array(pil_img)
            # NOTE(review): PIL yields RGB while cv2.imwrite expects BGR —
            # confirm yolo_detect / the model handle channel order.
            img = yolo_detect(img, ["box"])
            res = model.predict(img, conf=0.6)

            box_img = save_one_box(
                res[0].cpu().boxes.xyxy, im=res[0].orig_img, save=False
            )
            cv2.imwrite(out_path, box_img)
            print("File: ", out_path)

            # NOTE(review): find_scaling_square passes target_image straight to
            # cv2.findContours, which expects a single-channel (binary) image;
            # confirm box_img is not a multi-channel crop.
            scaling_factor = calculate_scaling_factor(
                reference_image_path="./Reference_ScalingBox.jpg",
                target_image=box_img,
                known_square_size_mm=12.7,
                feature_detector="ORB",
                debug=False,
                roi_margin=90,
            )
            print(f"Scaling Factor (mm per pixel): {scaling_factor:.6f}")
        except Exception as e:
            # print_exc writes the traceback itself and returns None.
            traceback.print_exc()
            print(f"Error: {e}")