from flask import Flask, request, jsonify, render_template, url_for
from flask_socketio import SocketIO
from ultralytics import YOLO
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
from werkzeug.utils import secure_filename
import logging
import json
import shutil
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
app = Flask(__name__)
socketio = SocketIO(app)
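# With no explicit async_mode, Flask-SocketIO picks the best available backend
# (eventlet or gevent if installed, otherwise plain threading)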

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
class Config:
    BASE_DIR = os.path.abspath(os.path.dirname(__file__))
    UPLOAD_FOLDER = os.path.join(BASE_DIR, 'static', 'uploads')
    SAM_RESULT_FOLDER = os.path.join(BASE_DIR, 'static', 'sam', 'sam_results')
    YOLO_RESULT_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'yolo_results')
    YOLO_TRAIN_IMAGE_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'train', 'images')
    YOLO_TRAIN_LABEL_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'train', 'labels')
    AREA_DATA_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'area_data')
    ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16MB max file size
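    # Flask itself enforces MAX_CONTENT_LENGTH, rejecting larger uploads with a 413 response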
    SAM_2 = os.path.join(BASE_DIR, 'static', 'sam', 'sam2.1_hiera_tiny.pt')
    YOLO_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'model_yolo.pt')
    RETRAINED_MODEL_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'model_retrained.pt')
    DATA_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'data.yaml')

app.config.from_object(Config)

# Ensure directories exist
for folder_key in ('UPLOAD_FOLDER', 'SAM_RESULT_FOLDER', 'YOLO_RESULT_FOLDER',
                   'YOLO_TRAIN_IMAGE_FOLDER', 'YOLO_TRAIN_LABEL_FOLDER', 'AREA_DATA_FOLDER'):
    os.makedirs(app.config[folder_key], exist_ok=True)


# Initialize YOLO model
try:
    model = YOLO(app.config['YOLO_PATH'])
except Exception as e:
    logger.error(f"Failed to initialize YOLO model: {str(e)}")
    raise

try:
    sam2_checkpoint = app.config['SAM_2']
    model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
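    # The config name is resolved by Hydra from the installed sam2 package,
    # not from this project's directory tree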

    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
    predictor = SAM2ImagePredictor(sam2_model)
except Exception as e:
    logger.error(f"Failed to initialize SAM 2 model: {str(e)}")
    raise

def allowed_file(filename):
    """Return True if the filename has an allowed image extension."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']

def scale_coordinates(coords, original_dims, target_dims):
    """
    Scale coordinates from one dimension space to another.
    
    Args:
        coords: List of [x, y] coordinates
        original_dims: Tuple of (width, height) of original space
        target_dims: Tuple of (width, height) of target space
    
    Returns:
        Scaled coordinates
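
    Example:
        >>> scale_coordinates([[10, 20]], (100, 200), (200, 100))
        [[20, 10]]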
    """
    scale_x = target_dims[0] / original_dims[0]
    scale_y = target_dims[1] / original_dims[1]
    
    return [
        [int(coord[0] * scale_x), int(coord[1] * scale_y)]
        for coord in coords
    ]

def scale_box(box, original_dims, target_dims):
    """
    Scale bounding box coordinates from one dimension space to another.
    
    Args:
        box: List of [x1, y1, x2, y2] coordinates
        original_dims: Tuple of (width, height) of original space
        target_dims: Tuple of (width, height) of target space
    
    Returns:
        Scaled box coordinates
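
    Example:
        >>> scale_box([10, 20, 30, 40], (100, 100), (200, 200))
        [20, 40, 60, 80]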
    """
    scale_x = target_dims[0] / original_dims[0]
    scale_y = target_dims[1] / original_dims[1]
    
    return [
        int(box[0] * scale_x),  # x1
        int(box[1] * scale_y),  # y1
        int(box[2] * scale_x),  # x2
        int(box[3] * scale_y)   # y2
    ]

def retrain_model_fn():
    """Retrain the YOLO model epoch by epoch, emitting Socket.IO progress updates."""
    # Parameters for retraining
    data_path = app.config['DATA_PATH']
    epochs = 5
    img_size = 640
    batch_size = 8

    # Train one epoch at a time so progress can be reported after each epoch.
    # Note: each call to model.train constructs a fresh trainer, so optimizer
    # state resets between epochs; this is not identical to one 5-epoch run.
    for epoch in range(epochs):
        model.train(
            data=data_path,
            epochs=1,  # one epoch per call, to report per-epoch progress
            imgsz=img_size,
            batch=batch_size,
            device="cpu"  # Adjust based on system capabilities
        )

        # Emit an update to the client after each epoch
        socketio.emit('training_update', {
            'epoch': epoch + 1,
            'status': f"Epoch {epoch + 1} complete"
        })

    # Emit a message once training is complete
    socketio.emit('training_complete', {'status': "Retraining complete"})
    model.save(app.config['YOLO_PATH'])
    retraining_status['status'] = 'complete'
    logger.info("Model retrained successfully")

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/yolo')
def yolo():
    return render_template('yolo.html')

@app.route('/upload_sam', methods=['POST'])
def upload_sam_file():
    """
    Handles SAM image upload and computes its embedding on the shared SAM 2 predictor.

    Returns:
        JSON response with 'message', 'image_url', 'filename', and 'dimensions' keys
            on success, or 'error' key with an appropriate error message on failure.
    """

    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No selected file'}), 400
        
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400
        
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)
        
        # Set the image for predictor right after upload
        image = cv2.imread(filepath)
        if image is None:
            return jsonify({'error': 'Failed to load uploaded image'}), 500
            
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
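        # NOTE: the predictor is a module-level singleton, so each upload
        # replaces the previously embedded image; concurrent SAM sessions
        # would overwrite one another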
        predictor.set_image(image)
        logger.info("Image embedded successfully")
        
        # Get image dimensions
        height, width = image.shape[:2]
        
        image_url = url_for('static', filename=f'uploads/{filename}')
        logger.info(f"File uploaded successfully: {filepath}")
        
        return jsonify({
            'message': 'File uploaded successfully',
            'image_url': image_url,
            'filename': filename,
            'dimensions': {
                'width': width,
                'height': height
            }
        })
        
    except Exception as e:
        logger.error(f"Upload error: {str(e)}")
        return jsonify({'error': 'Server error during upload'}), 500

@app.route('/upload_yolo', methods=['POST'])
def upload_yolo_file():
    """
    Upload an image for YOLO classification.

    This endpoint accepts a POST request containing a single image file. The
    file is saved to the uploads folder for later classification via /classify.
    
    Returns a JSON response with the following keys:
    - message: a success message
    - image_url: the URL of the uploaded image
    - filename: the name of the uploaded file
    
    If an error occurs, the JSON response will contain an 'error' key with a
    descriptive error message.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No selected file'}), 400
        
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400
        
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        image_url = url_for('static', filename=f'uploads/{filename}')
        logger.info(f"File uploaded successfully: {filepath}")
        
        return jsonify({
            'message': 'File uploaded successfully',
            'image_url': image_url,
            'filename': filename,
        })
        
    except Exception as e:
        logger.error(f"Upload error: {str(e)}")
        return jsonify({'error': 'Server error during upload'}), 500

@app.route('/generate_mask', methods=['POST'])
def generate_mask():
    """
    Generate a mask for a given image using the YOLO model
    @param data: a JSON object containing the following keys:
        - filename: the name of the image file
        - normalized_void_points: a list of normalized 2D points (x, y) representing the voids
        - normalized_component_boxes: a list of normalized 2D bounding boxes (x, y, w, h) representing the components
    @return: a JSON object containing the following keys:
        - status: a string indicating the status of the request
        - train_image_url: the URL of the saved train image
        - result_path: the URL of the saved result image
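
    Example request body (hypothetical values):
        {
            "filename": "board.png",
            "void_points": [[0.42, 0.31]],
            "component_boxes": [[0.10, 0.15, 0.55, 0.60]]
        }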
    """
    try:
        data = request.json
        normalized_void_points = data.get('void_points', [])
        normalized_component_boxes = data.get('component_boxes', [])
        filename = data.get('filename', '')
        
        if not filename:
            return jsonify({'error': 'No filename provided'}), 400
            
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if not os.path.exists(image_path):
            return jsonify({'error': 'Image file not found'}), 404

        # Read image
        image = cv2.imread(image_path)
        if image is None:
            return jsonify({'error': 'Failed to load image'}), 500
            
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_height, image_width = image.shape[:2]
        
        # Denormalize coordinates back to pixel values
        void_points = [
            [int(point[0] * image_width), int(point[1] * image_height)]
            for point in normalized_void_points
        ]
        logger.info(f"Void points: {void_points}")
        
        component_boxes = [
            [
                int(box[0] * image_width),
                int(box[1] * image_height),
                int(box[2] * image_width),
                int(box[3] * image_height)
            ]
            for box in normalized_component_boxes
        ]
        logger.info(f"Component boxes: {component_boxes}")

        # Create a list to store individual void masks
        void_masks = []
        
        # Process void points one by one
        for point in void_points:
            # SAM 2 expects point prompts as an [N, 2] array of (x, y) pixels
            point_coord = np.array([[point[0], point[1]]])
            point_label = np.array([1])  # 1 = positive (foreground) point
            
            masks, scores, _ = predictor.predict(
                point_coords=point_coord,
                point_labels=point_label,
                multimask_output=True  # Get multiple masks
            )
            
            if len(masks) > 0:  # Check if any masks were generated
                # Get the mask with highest score
                best_mask_idx = np.argmax(scores)
                void_masks.append(masks[best_mask_idx])
                logger.info(f"Processed void point {point} with score {scores[best_mask_idx]}")

        # Process component boxes
        component_masks = []
        if component_boxes:
            for box in component_boxes:
                # SAM 2 takes box prompts as a length-4 array in XYXY format
                box_np = np.array(box)
                masks, scores, _ = predictor.predict(
                    box=box_np,
                    multimask_output=True
                )
                if len(masks) > 0:
                    best_mask_idx = np.argmax(scores)
                    component_masks.append(masks[best_mask_idx])
                    logger.info(f"Processed component box {box}")

        # Create visualization with different colors for each void
        combined_image = image.copy()

        # Font settings for labels
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        font_color = (0, 0, 0)  # Black text color
        font_thickness = 1
        background_color = (255, 255, 255)  # White background for text

        # Helper function to get bounding box coordinates
        def get_bounding_box(mask):
            # np.where returns (row, col) indices, i.e. (y, x)
            coords = np.column_stack(np.where(mask))
            y_min, x_min = coords.min(axis=0)
            y_max, x_max = coords.max(axis=0)
            return (x_min, y_min, x_max, y_max)

        # Helper function to add text with background
        def put_text_with_background(img, text, pos):
            # Calculate text size
            (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
            # Define the rectangle coordinates for background
            background_tl = (pos[0], pos[1] - text_h - 2)
            background_br = (pos[0] + text_w, pos[1] + 2)
            # Draw white rectangle as background
            cv2.rectangle(img, background_tl, background_br, background_color, -1)
            # Put the text over the background rectangle
            cv2.putText(img, text, pos, font, font_scale, font_color, font_thickness, cv2.LINE_AA)

        # Helper function to clamp a label position inside the image bounds
        def get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h, img_width, img_height):
            # Default to the top-right corner of the bounding box
            x_pos = min(x_max, img_width - text_w - 10)  # Keep a 10px margin from the right
            y_pos = max(y_min + text_h + 5, text_h + 5)  # Keep a 5px margin from the top
            return x_pos, y_pos


        # Apply void masks with different colors
        for mask in void_masks:
            mask = mask.astype(bool)
            combined_image[mask, 0] = np.clip(0.5 * image[mask, 0] + 0.5 * 255, 0, 255)  # Red channel with transparency
            combined_image[mask, 1] = np.clip(0.5 * image[mask, 1], 0, 255)              # Green channel reduced
            combined_image[mask, 2] = np.clip(0.5 * image[mask, 2], 0, 255)
            logger.info("Mask Drawn")  

        # Apply component masks in green
        # Pre-compute the union of all void masks (empty if none were drawn),
        # since np.any([], axis=0) would raise on an empty list
        void_union = np.any(void_masks, axis=0) if void_masks else np.zeros(image.shape[:2], dtype=bool)
        for mask in component_masks:
            mask = mask.astype(bool)
            # Only apply green where there is no red (void) overlay
            non_red_area = mask & ~void_union
            combined_image[non_red_area, 0] = np.clip(0.5 * image[non_red_area, 0], 0, 255)              # Reduced red channel
            combined_image[non_red_area, 1] = np.clip(0.5 * image[non_red_area, 1] + 0.5 * 255, 0, 255)  # Green channel
            combined_image[non_red_area, 2] = np.clip(0.5 * image[non_red_area, 2], 0, 255)
            logger.info("Mask Drawn")


        # Add labels on top of masks
        for i, mask in enumerate(void_masks):
            label = f"Void {i + 1}"
            x_min, y_min, x_max, y_max = get_bounding_box(mask)
            # Measure the actual label text so the background box fits it
            (text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, font_thickness)
            label_position = get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h, combined_image.shape[1], combined_image.shape[0])
            put_text_with_background(combined_image, label, label_position)

        for i, mask in enumerate(component_masks):
            label = f"Component {i + 1}"
            x_min, y_min, x_max, y_max = get_bounding_box(mask)
            (text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, font_thickness)
            label_position = get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h, combined_image.shape[1], combined_image.shape[0])
            put_text_with_background(combined_image, label, label_position)

        # Prepare an empty list to store the output in the required format
        mask_coordinates = []

        for mask in void_masks:
            # Get contours from the mask
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Image dimensions
            height, width = mask.shape

            # For each contour, extract the normalized coordinates
            for contour in contours:
                contour_points = contour.reshape(-1, 2)  # Flatten to (N, 2) where N is the number of points
                normalized_points = contour_points / [width, height]  # Normalize to (0, 1)

                class_id = 1  # 1 for voids
                row = [class_id] + normalized_points.flatten().tolist()  # Flatten and add the class
                mask_coordinates.append(row)

        for mask in component_masks:
            # Get contours from the mask
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Filter to keep only the largest contour
            contours = sorted(contours, key=cv2.contourArea, reverse=True)
            largest_contour = [contours[0]] if contours else []
            # Image dimensions
            height, width = mask.shape

            # For each contour, extract the normalized coordinates
            for contour in largest_contour:
                contour_points = contour.reshape(-1, 2)  # Flatten to (N, 2) where N is the number of points
                normalized_points = contour_points / [width, height]  # Normalize to (0, 1)

                class_id = 0  # 0 for components
                row = [class_id] + normalized_points.flatten().tolist()  # Flatten and add the class
                mask_coordinates.append(row)

        # YOLO expects each label file to share the image's basename
        # (e.g. image.txt alongside image.png), so strip the extension
        mask_coordinates_filename = f'{os.path.splitext(filename)[0]}.txt'
        mask_coordinates_path = os.path.join(app.config['YOLO_TRAIN_LABEL_FOLDER'], mask_coordinates_filename)

        with open(mask_coordinates_path, "w") as file:
            for row in mask_coordinates:
                # Join elements of the row into a string with spaces in between and write to the file
                file.write(" ".join(map(str, row)) + "\n")

        # Save train image
        train_image_filepath = os.path.join(app.config['YOLO_TRAIN_IMAGE_FOLDER'], filename)
        shutil.copy(image_path, train_image_filepath)
        train_image_url = url_for('static', filename=f'yolo/dataset_yolo/train/images/{filename}')

        # Save result
        result_filename = f'segmented_{filename}'
        result_path = os.path.join(app.config['SAM_RESULT_FOLDER'], result_filename)
        plt.imsave(result_path, combined_image)
        logger.info("Mask generation completed successfully")
        
        return jsonify({
            'status': 'success',
            'train_image_url':train_image_url,
            'result_path': url_for('static', filename=f'sam/sam_results/{result_filename}')
        })

    except Exception as e:
        logger.error(f"Mask generation error: {str(e)}")
        return jsonify({'error': str(e)}), 500

@app.route('/classify', methods=['POST'])
def classify():
    """
    Classify an image and return the classification result, area data, and the annotated image.

    Request body should contain a JSON object with a single key 'filename' specifying the image file to be classified.

    Returns a JSON object with the following keys:

    - status: 'success' if the classification is successful, 'error' if there is an error.
    - result_path: URL of the annotated image.
    - area_data: a list of dictionaries containing the area and overlap statistics for each component.
    - area_data_path: URL of the JSON file containing the area data.

    If there is an error, returns a JSON object with a single key 'error' containing the error message.
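
    Example request body (hypothetical filename):
        {"filename": "board.png"}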
    """

    try:
        data = request.json
        filename = data.get('filename', '')
        if not filename:
            return jsonify({'error': 'No filename provided'}), 400
            
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if not os.path.exists(image_path):
            return jsonify({'error': 'Image file not found'}), 404
        
        # Read image
        image = cv2.imread(image_path)
        if image is None:
            return jsonify({'error': 'Failed to load image'}), 500 

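        # Ultralytics accepts BGR numpy arrays directly, so no BGR->RGB
        # conversion is needed before inference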
        results = model(image)
        result = results[0]

        component_masks = []
        void_masks = []

        # Extract masks and labels from the results; result.masks is None
        # when nothing was detected
        if result.masks is not None:
            for mask, label in zip(result.masks.data, result.boxes.cls):
                mask_array = mask.cpu().numpy().astype(bool)  # Convert to a binary mask (boolean array)
                if label == 1:  # Label '1' represents void
                    void_masks.append(mask_array)
                elif label == 0:  # Label '0' represents component
                    component_masks.append(mask_array)

        # Calculate area and overlap statistics
        area_data = []
        for i, component_mask in enumerate(component_masks):
            component_area = np.sum(component_mask).item()  # Total component area in pixels
            void_area_within_component = 0
            max_void_area_percentage = 0
            
            # Calculate overlap of each void mask with the component mask
            for void_mask in void_masks:
                overlap_area = np.sum(void_mask & component_mask).item()  # Overlapping area
                void_area_within_component += overlap_area
                void_area_percentage = (overlap_area / component_area) * 100 if component_area > 0 else 0
                max_void_area_percentage = max(max_void_area_percentage, void_area_percentage)
            
            # Append data for this component
            area_data.append({
                "Image": filename,
                'Component': f'Component {i+1}',
                'Area': component_area,
                'Void Area (pixels)': void_area_within_component,
                'Void Area %': void_area_within_component / component_area * 100 if component_area > 0 else 0,
                'Max Void Area %': max_void_area_percentage
            })

        # Name the JSON file after the image basename, e.g. area_data_image.json
        area_data_filename = f'area_data_{os.path.splitext(filename)[0]}.json'
        area_data_path = os.path.join(app.config['AREA_DATA_FOLDER'], area_data_filename)

        with open(area_data_path, 'w') as json_file:
            json.dump(area_data, json_file, indent=4)

        # result.plot() returns a BGR array, so save it with OpenCV rather
        # than matplotlib (which expects RGB)
        annotated_image = result.plot()

        output_filename = f'output_{filename}'
        output_image_path = os.path.join(app.config['YOLO_RESULT_FOLDER'], output_filename)
        cv2.imwrite(output_image_path, annotated_image)
        logger.info("Classification completed successfully")  

        return jsonify({
            'status': 'success',
            'result_path': url_for('static', filename=f'yolo/yolo_results/{output_filename}'),
            'area_data': area_data,
            'area_data_path': url_for('static', filename=f'yolo/area_data/{area_data_filename}')
        })
    except Exception as e:
        logger.error(f"Classification error: {str(e)}")
        return jsonify({'error': str(e)}), 500

retraining_status = {
    'status': 'idle',
    'progress': None,
    'message': None
}

@app.route('/start_retraining', methods=['GET', 'POST'])
def start_retraining():
    """
    Start the model retraining process.

    If the request is a POST, start the model retraining process in a separate thread.
    If the request is a GET, render the retraining page.

    Returns:
        A JSON response with the status of the retraining process, or a rendered HTML page.
    """
    if request.method == 'POST':
        # Reset status
        global retraining_status
        retraining_status['status'] = 'in_progress'
        retraining_status['progress'] = 'Initializing'
        
        # Run retraining as a Flask-SocketIO background task so progress
        # events can be emitted from outside the request context
        socketio.start_background_task(retrain_model_fn)
        return jsonify({'status': 'started'})
    else:
        # GET request - render the retraining page
        return render_template('retrain.html')

# Event handler for client connection
@socketio.on('connect')
def handle_connect():
    logger.info('Client connected')


if __name__ == '__main__':
    # Use socketio.run so the Socket.IO endpoints are served alongside Flask
    socketio.run(app, port=5001, debug=True)