from transformers import DPTImageProcessor, DPTForDepthEstimation
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
import gradio as gr
import supervision as sv
import torch
import numpy as np
from PIL import Image
import requests

class DepthPredictor:
    """Monocular depth estimation with the ``Intel/dpt-large`` DPT checkpoint.

    The image processor and the model are loaded once at construction time;
    ``predict`` then turns a single image into an 8-bit depth map resized to
    the image's own resolution.
    """

    def __init__(self):
        """Load the processor and model and place the model on the device."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
        self.model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
        # The device is only useful if the weights actually live on it.
        self.model.to(self.device)
        self.model.eval()

    def predict(self, image):
        """Estimate depth for one image.

        Args:
            image: Image accepted by the DPT image processor. It must expose
                ``.size`` as ``(width, height)``, as PIL images do, because
                that is used to resize the prediction.

        Returns:
            ``uint8`` NumPy array scaled so the largest predicted depth maps
            to 255; ``(height, width)`` for ordinary images.
        """
        encoding = self.feature_extractor(image, return_tensors="pt")
        # Inputs must sit on the same device as the model weights.
        encoding = encoding.to(self.device)

        with torch.no_grad():
            predicted_depth = self.model(**encoding).predicted_depth

            # interpolate() wants (N, C, H, W); PIL's size is (W, H), hence
            # the added channel axis and the reversed size tuple.
            prediction = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1),
                size=image.size[::-1],
                mode="bicubic",
                align_corners=False,
            ).squeeze()

        return self._to_uint8(prediction.cpu().numpy())

    @staticmethod
    def _to_uint8(depth):
        """Scale a depth array to ``uint8`` with its maximum mapped to 255.

        An array whose maximum is not positive (all zeros, all negative, or
        NaN) yields all zeros instead of dividing by zero. Negative entries
        are clamped to 0 so they cannot wrap around in the ``uint8`` cast.
        """
        peak = np.max(depth)
        if not peak > 0:
            return np.zeros(depth.shape, dtype="uint8")
        scaled = depth * 255 / peak
        return np.clip(scaled, 0, 255).astype("uint8")
    



class SegmentPredictor:
    """Segment Anything (SAM) wrapper for prompted and automatic segmentation.

    A single SAM model backs both a ``SamAutomaticMaskGenerator`` (used by
    ``segment_everything``) and a ``SamPredictor`` (used by ``encode`` and
    ``cond_pred``).
    """

    def __init__(self, model_type="vit_b", checkpoint="sam_vit_b_01ec64.pth"):
        """Load SAM and build the mask generator and the prompt predictor.

        Args:
            model_type: Key into ``sam_model_registry`` selecting the model
                variant.
            checkpoint: Path to the matching SAM checkpoint file.
        """
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        # Select device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        sam.to(device=self.device)
        self.mask_generator = SamAutomaticMaskGenerator(sam)
        self.conditioned_pred = SamPredictor(sam)

    @staticmethod
    def _as_array(values):
        """Return ``values`` as a NumPy array, passing ``None`` through.

        Existing arrays are returned unchanged; lists and tuples are
        converted so callers may collect prompts in plain Python containers.
        """
        if values is None:
            return None
        return np.asarray(values)

    def encode(self, image):
        """Hand ``image`` to the prompt predictor for later ``cond_pred`` calls.

        Args:
            image: Anything ``np.array`` can convert to an image array
                (presumably an RGB PIL image -- confirm with callers).
        """
        image = np.array(image)
        self.conditioned_pred.set_image(image)

    def cond_pred(self, pts, lbls):
        """Predict masks for the encoded image from point prompts.

        Call ``encode`` first; it is what gives the predictor its image.

        Args:
            pts: Point coordinates, as an array or a nested list.
            lbls: One label per point, as an array or a list.

        Returns:
            The masks returned by ``SamPredictor.predict``; the scores and
            low-resolution logits it also returns are discarded.
        """
        masks, _, _ = self.conditioned_pred.predict(
            point_coords=self._as_array(pts),
            point_labels=self._as_array(lbls),
            multimask_output=True,
        )
        return masks

    def segment_everything(self, image):
        """Segment the whole image and return it with the masks drawn on.

        Args:
            image: Anything ``np.array`` can convert to an image array.

        Returns:
            A copy of the image array annotated by ``sv.MaskAnnotator`` with
            every mask the automatic generator produced.
        """
        image = np.array(image)
        sam_result = self.mask_generator.generate(image)
        mask_annotator = sv.MaskAnnotator()
        detections = sv.Detections.from_sam(sam_result=sam_result)
        # Annotate a copy so the converted input array is left untouched.
        annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
        return annotated_image