import pandas as pd
import numpy as np
import streamlit as st 
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image
from streamlit_image_select import image_select
from tqdm import tqdm
import os
import shutil
from PIL import Image
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForMaskGeneration

def show_mask(image, mask, ax=None):
    """Render ``mask`` as a translucent blue overlay on top of ``image``.

    Args:
        image: Anything ``np.array`` can turn into an array that
            ``Axes.imshow`` accepts (the script passes a PIL image).
        mask: Array whose last two dimensions are ``(height, width)`` and
            whose total size is ``height * width``. Each entry scales the
            overlay colour, so zero entries stay fully transparent.
        ax: Unused. Accepted only so existing callers that pass an axes keep
            working; a fresh figure is always created.

    Returns:
        PIL.Image.Image: RGB rendering of the whole Matplotlib figure.
    """
    fig, axes = plt.subplots()
    try:
        axes.imshow(np.array(image))
        overlay_rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        h, w = mask.shape[-2:]
        # (h, w, 1) * (1, 1, 4) broadcasts to an (h, w, 4) RGBA image.
        axes.imshow(mask.reshape(h, w, 1) * overlay_rgba.reshape(1, 1, -1))
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and later removed;
        # ``buffer_rgba`` is the supported way to read the rendered pixels.
        # ``np.array`` copies, so the pixels outlive the closed figure.
        rgba = np.array(canvas.buffer_rgba())
        return Image.fromarray(rgba).convert('RGB')
    finally:
        # Always release the figure, even if rendering failed part-way.
        plt.close(fig)
def get_bounding_box(ground_truth_map, *, max_perturbation=20):
    """Return a randomly enlarged bounding box around the positive pixels.

    The tight box around every entry ``> 0`` is pushed outwards on each side
    by an independent random amount drawn from ``[0, max_perturbation)`` and
    then clamped to the map's extent.

    Args:
        ground_truth_map: 2-D array-like mask; entries greater than zero are
            treated as foreground.
        max_perturbation: Exclusive upper bound of the per-side random
            enlargement in pixels. ``0`` disables the jitter and yields the
            tight box.

    Returns:
        list[int]: ``[x_min, y_min, x_max, y_max]``.

    Raises:
        ValueError: If the mask has no positive pixel or
            ``max_perturbation`` is negative.
    """
    if max_perturbation < 0:
        raise ValueError("max_perturbation must be non-negative")

    ground_truth_map = np.asarray(ground_truth_map)
    y_indices, x_indices = np.where(ground_truth_map > 0)
    if x_indices.size == 0:
        # Without this guard NumPy fails with an opaque reduction error.
        raise ValueError(
            "ground_truth_map contains no positive pixels; "
            "cannot derive a bounding box"
        )
    height, width = ground_truth_map.shape

    def jitter():
        # ``randint(0, 0)`` is invalid, so handle the no-jitter case here.
        if max_perturbation == 0:
            return 0
        return int(np.random.randint(0, max_perturbation))

    # Draw order (x_min, x_max, y_min, y_max) is kept stable so seeded runs
    # stay reproducible.
    x_min = max(0, int(x_indices.min()) - jitter())
    x_max = min(width, int(x_indices.max()) + jitter())
    y_min = max(0, int(y_indices.min()) - jitter())
    y_max = min(height, int(y_indices.max()) + jitter())
    return [x_min, y_min, x_max, y_max]
def get_output(image, prompt):
    """Segment ``image`` with the module-level model using one box prompt.

    Relies on the module-level ``processor``, ``model`` and ``device``.

    Args:
        image: Image handed to ``processor`` unchanged.
        prompt: A single bounding box; it is wrapped as ``[[prompt]]`` before
            being passed as ``input_boxes``.

    Returns:
        numpy.ndarray: ``uint8`` array holding 1 where the sigmoid of the
        predicted mask logits exceeds 0.5 and 0 elsewhere.
    """
    batch = processor(image, input_boxes=[[prompt]], return_tensors='pt')
    batch = batch.to(device)

    model.eval()
    with torch.no_grad():
        prediction = model(**batch, multimask_output=False)

    mask_logits = prediction.pred_masks.squeeze(1)
    probabilities = torch.sigmoid(mask_logits).cpu().numpy().squeeze()
    return (probabilities > 0.5).astype(np.uint8)
def generate_image(np_array):
    """Multiply ``np_array`` by 255 and wrap it in an 8-bit grayscale image.

    Args:
        np_array: NumPy array; the script passes the 0/1 mask produced by
            ``get_output``.

    Returns:
        PIL.Image.Image: Image in mode ``'L'``.
    """
    scaled = np_array * 255
    pixels = scaled.astype('uint8')
    return Image.fromarray(pixels, mode='L')
def iou_calculation(result1, result2):
    """Return the intersection-over-union of two masks, rounded to 4 places.

    Any non-zero entry counts as foreground.

    Args:
        result1: First mask; anything ``np.asarray`` accepts.
        result2: Second mask with a shape compatible with ``result1``.

    Returns:
        float: ``|A & B| / |A | B|`` rounded to four decimals. When neither
        mask has any foreground the ratio is undefined and ``0.0`` is
        returned, matching what the epsilon-smoothed ``calculate_f1_score``
        yields for that case.
    """
    first = np.asarray(result1).astype(bool)
    second = np.asarray(result2).astype(bool)

    union = int(np.count_nonzero(np.logical_or(first, second)))
    if union == 0:
        # Avoids the 0/0 -> NaN (plus RuntimeWarning) of a plain division.
        return 0.0

    intersection = int(np.count_nonzero(np.logical_and(first, second)))
    return round(intersection / union, 4)
def calculate_pixel_accuracy(image1, image2):
    """Return the fraction of pixels whose RGB value is identical in both images.

    ``image1`` is first brought to ``image2``'s size and mode; both are then
    compared as RGB.

    Args:
        image1: PIL image (the script passes the ground-truth label).
        image2: PIL image (the script passes the predicted mask).

    Returns:
        float: Matching pixels divided by the total pixel count.

    Raises:
        ZeroDivisionError: If the images contain no pixels.
    """
    if image1.size != image2.size:
        # NEAREST keeps the original label values. An interpolating filter
        # would invent in-between values along edges that can never match.
        image1 = image1.resize(image2.size, Image.NEAREST)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)

    width, height = image1.size
    total_pixels = width * height

    pixels1 = np.asarray(image1.convert("RGB"))
    pixels2 = np.asarray(image2.convert("RGB"))
    # A pixel matches only when all three channels agree.
    matches = np.all(pixels1 == pixels2, axis=-1)
    return int(np.count_nonzero(matches)) / total_pixels
def calculate_f1_score(image1, image2):
    """Return the F1 score of ``image2`` measured against ``image1``.

    Both images are reduced to 8-bit grayscale; only pixels equal to 255
    count as foreground.

    Args:
        image1: PIL image used as the reference mask.
        image2: PIL image used as the predicted mask.

    Returns:
        float: ``2 * P * R / (P + R)`` with a ``1e-7`` term in every
        denominator, so the result is ``0.0`` rather than undefined when
        there is no foreground at all.
    """
    if image1.size != image2.size:
        # NEAREST keeps mask values at exactly 0/255. An interpolating
        # filter would turn edge pixels grey and drop them from ``== 255``.
        image1 = image1.resize(image2.size, Image.NEAREST)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)

    reference = np.asarray(image1.convert("L")) == 255
    predicted = np.asarray(image2.convert("L")) == 255

    true_positives = int(np.count_nonzero(reference & predicted))
    false_positives = int(np.count_nonzero(~reference & predicted))
    false_negatives = int(np.count_nonzero(reference & ~predicted))

    precision = true_positives / (true_positives + false_positives + 1e-7)
    recall = true_positives / (true_positives + false_negatives + 1e-7)
    return 2 * (precision * recall) / (precision + recall + 1e-7)
def calculate_dice_coefficient(image1, image2):
    """Return the Dice coefficient between two mask images.

    Both images are reduced to 8-bit grayscale; only pixels equal to 255
    count as foreground.

    Args:
        image1: PIL image used as the reference mask.
        image2: PIL image used as the predicted mask.

    Returns:
        float: ``2 * TP / (2 * TP + FP + FN)``. When neither image has any
        foreground the ratio is undefined and ``0.0`` is returned, matching
        what the epsilon-smoothed ``calculate_f1_score`` yields for that case.
    """
    if image1.size != image2.size:
        # NEAREST keeps mask values at exactly 0/255. An interpolating
        # filter would turn edge pixels grey and drop them from ``== 255``.
        image1 = image1.resize(image2.size, Image.NEAREST)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)

    reference = np.asarray(image1.convert("L")) == 255
    predicted = np.asarray(image2.convert("L")) == 255

    true_positives = int(np.count_nonzero(reference & predicted))
    false_positives = int(np.count_nonzero(~reference & predicted))
    false_negatives = int(np.count_nonzero(reference & ~predicted))

    denominator = 2 * true_positives + false_positives + false_negatives
    if denominator == 0:
        # Avoids the 0/0 -> NaN (plus RuntimeWarning) of a plain division.
        return 0.0
    return (2 * true_positives) / denominator
device = "cuda" if torch.cuda.is_available() else "cpu"
# Page configuration is issued before any other Streamlit element.
st.set_page_config(layout='wide')

# Dataset rows offered to the user as selectable samples.
SAMPLE_INDICES = (3, 4, 5, 6)


# Streamlit re-executes this script on every interaction. The loaders below
# are cached so the dataset, processor and model are created once per server
# process instead of once per rerun.
@st.cache_resource
def _load_dataset():
    """Load the training split of the mask dataset."""
    return load_dataset('ahishamm/combined_masks', split='train')


@st.cache_resource
def _load_processor():
    """Load the pretrained input processor."""
    return AutoProcessor.from_pretrained('ahishamm/skinsam')


@st.cache_resource
def _load_model(target_device):
    """Load the pretrained mask-generation model and move it to ``target_device``."""
    segmenter = AutoModelForMaskGeneration.from_pretrained('ahishamm/skinsam_focalloss_base_combined')
    segmenter.to(target_device)
    return segmenter


ds = _load_dataset()
# Fetch each sample row once and split it into image and label.
_sample_rows = [ds[index] for index in SAMPLE_INDICES]
image_arr = [row['image'] for row in _sample_rows]
label_arr = [row['label'] for row in _sample_rows]
# Individual names kept for backward compatibility with any later code.
s1, s2, s3, s4 = image_arr
s1_label, s2_label, s3_label, s4_label = label_arr

img = image_select(
    label="Select a Skin Lesion Image",
    images=image_arr,
    captions=["sample 1", "sample 2", "sample 3", "sample 4"],
    return_value='index',
)

# ``get_output`` reads ``processor``, ``model`` and ``device`` as globals.
processor = _load_processor()
model = _load_model(device)

# Box prompt derived from the ground-truth label of the selected sample.
p = get_bounding_box(np.array(label_arr[img]))
predicted_mask_array = get_output(image_arr[img], p)
predicted_mask = generate_image(predicted_mask_array)
result_image = show_mask(image_arr[img], predicted_mask_array)

with st.container():
    tab1, tab2 = st.tabs(['Visualizations', 'Metrics'])
    with tab1:
        col1, col2 = st.columns(2)
        with col1:
            st.image(image_arr[img], caption='Original Skin Lesion Image', use_column_width=True)
        with col2:
            st.image(result_image, caption='Mask Overlay', use_column_width=True)
    with tab2:
        st.write(f'The IOU Score: {iou_calculation(label_arr[img],predicted_mask)}')
        st.write(f'The Pixel Accuracy: {calculate_pixel_accuracy(label_arr[img],predicted_mask)}')
        st.write(f'The Dice Coefficient Score: {calculate_dice_coefficient(label_arr[img],predicted_mask)}')