import functools

import gradio as gr
import numpy as np
import cv2
from PIL import Image
from ultralytics import YOLO
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# Display name (shown in the UI dropdown) -> YOLO weights file.
# Add more entries here to expose additional models.
_MODEL_WEIGHTS = {
    "X-ray": "xray.pt",
    "CT scan": "CT.pt",
    "Ultrasound": "ultrasound.pt",
}

# Every model is loaded once, at import time, in the order listed above.
available_models = {name: YOLO(weights) for name, weights in _MODEL_WEIGHTS.items()}

SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_MODEL_TYPE = "vit_h"
SAM_DEVICE = "cpu"


@functools.lru_cache(maxsize=1)
def _load_sam():
    """Build the SAM model once and reuse it for every request.

    The ViT-H checkpoint is large, so loading it per request (as the app
    used to) dominates latency.  Only the model is cached; callers wrap it
    in a fresh ``SamPredictor`` so per-image predictor state is not shared
    between concurrent requests.
    """
    sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
    sam.to(device=SAM_DEVICE)
    return sam


def _overlay(image, mask, color):
    """Return a PIL copy of RGB *image* with pixels where *mask* > 0 set to *color*.

    *mask* may be bool, int or float; it is binarized to uint8 first because
    ``cv2.resize`` rejects bool/int64 arrays.  It is resized (nearest
    neighbour, so it stays binary) only if its shape differs from the image.
    """
    height, width = image.shape[:2]
    binary = (np.asarray(mask) > 0).astype(np.uint8)
    if binary.shape != (height, width):
        binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)
    painted = image.copy()
    painted[binary > 0] = color
    return Image.fromarray(painted)


def segment_image(input_image, selected_model):
    """Segment *input_image* with a YOLO model, then refine the result with SAM.

    Args:
        input_image: PIL image from the Gradio upload component.
        selected_model: key into ``available_models`` (the dropdown choice).

    Returns:
        A pair of PIL images:
        the input with the first YOLO mask painted cyan, and the input with
        the SAM mask (prompted by the first YOLO box) painted yellow.

    Raises:
        gr.Error: when no image is given, the model name is unknown, or YOLO
            detects nothing.  Gradio shows the message to the user.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    if selected_model not in available_models:
        raise gr.Error(f"Unknown model: {selected_model!r}")

    # Force 3-channel RGB in case a grayscale/RGBA image slips through.
    rgb = np.array(input_image.convert("RGB"))
    # Ultralytics treats numpy inputs as BGR, so YOLO gets a BGR copy; SAM and
    # the displayed overlays use the RGB array so colours are not swapped.
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    model = available_models[selected_model]

    # retina_masks=True returns masks at the original image resolution, so they
    # align with the image instead of the letterboxed inference size.
    result = model(bgr, retina_masks=True)[0]
    if result.masks is None or len(result.boxes) == 0:
        raise gr.Error("No object was detected in this image.")

    # Masks and boxes are index-aligned per detection; use the first one.
    yolo_mask = result.masks.data[0].cpu().numpy()
    yolo_overlay = _overlay(rgb, yolo_mask, [0, 255, 255])  # cyan

    # Prompt SAM with the first YOLO box (x1, y1, x2, y2 in pixels).
    input_box = result.boxes.xyxy[0].cpu().numpy()
    predictor = SamPredictor(_load_sam())
    predictor.set_image(rgb)  # SamPredictor's default image_format is "RGB"
    sam_masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box,
        multimask_output=False,
    )
    sam_overlay = _overlay(rgb, sam_masks[0], [255, 255, 0])  # yellow

    return yolo_overlay, sam_overlay

# Build the Gradio UI: an image upload plus a dropdown selecting which YOLO
# model to run.  The two outputs are the YOLO-only and the YOLO+SAM overlays
# returned by segment_image.
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.components.Image(type="pil", label="Upload an image"),
        gr.components.Dropdown(
            choices=list(available_models.keys()),
            label="Select YOLO Model",
            # `value` sets the initial selection; `default` is not a supported
            # Dropdown argument in current Gradio, so it left no model selected.
            value="X-ray",
        ),
    ],
    outputs=[
        gr.components.Image(type="pil", label="YOLO predicted mask and images"),
        gr.components.Image(type="pil", label="YOLO and SAM predicted mask and images "),
    ],
    title="YOLOv8 with SAM 😃",
    description="This software generates segmentation masks for medical images.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()