from transformers import DPTImageProcessor, DPTForDepthEstimation
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
import gradio as gr
import supervision as sv
import torch
import numpy as np
from PIL import Image
import requests
import open3d as o3d
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt

class DepthPredictor:
    """DPT monocular depth estimation plus point-cloud, plot and mesh export helpers."""

    # Per-mask cube colours for generate_obj_masks, as RGB floats in [0, 1]:
    # Open3D's paint_uniform_color expects that range and clips 0-255 values
    # (logging a warning for every cube painted).
    _MASK_COLORS = ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
        self.model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
        # Put the model on the selected device so a GPU is actually used when present.
        self.model.to(self.device)
        self.model.eval()

    def predict(self, image):
        """Return the model's per-pixel depth prediction for ``image``.

        Args:
            image: A PIL image; ``image.size`` is read as ``(width, height)``.

        Returns:
            A 2-D float numpy array of shape ``(height, width)``: the prediction
            bicubically resized to the input resolution.
            NOTE(review): DPT predicts relative depth rather than metric distance —
            confirm downstream consumers do not assume metric units.
        """
        # Inputs must live on the same device as the model.
        encoding = self.feature_extractor(image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**encoding)
            # predicted_depth is 3-D (batch, h, w); add a channel axis so
            # interpolate can resize it back to the original (height, width).
            prediction = torch.nn.functional.interpolate(
                outputs.predicted_depth.unsqueeze(1),
                size=image.size[::-1],
                mode="bicubic",
                align_corners=False,
            )
        # Index the single batch/channel instead of squeeze(), so an image that is
        # one pixel wide or tall still yields a 2-D map.
        return prediction[0, 0].cpu().numpy()

    def generate_pcl(self, image):
        """Back-project ``image`` into a coloured point cloud using the predicted depth.

        Returns:
            ``(points, colors)`` as numpy arrays with one row of three values per
            point, taken from the Open3D point cloud.
        """
        depth = self.predict(image)
        depth_o3d = o3d.geometry.Image(depth)
        image_o3d = o3d.geometry.Image(np.array(image))
        rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
            image_o3d, depth_o3d, convert_rgb_to_intensity=False
        )
        # NOTE(review): fixed PrimeSense intrinsics and Open3D's default depth
        # scale/truncation are used whatever the input image — confirm intended.
        intrinsic = o3d.camera.PinholeCameraIntrinsic(
            o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault
        )
        pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic)
        return np.asarray(pcd.points), np.asarray(pcd.colors)

    def generate_fig(self, image):
        """Return a Plotly 3-D scatter of the point cloud, coloured by the red channel."""
        points, colors = self.generate_pcl(image)
        df = pd.DataFrame({
            'x': points[:, 0], 'y': points[:, 1], 'z': points[:, 2],
            'red': colors[:, 0], 'green': colors[:, 1], 'blue': colors[:, 2],
        })
        # Same tiny marker size for every point.
        size = np.full(len(df), 0.01)
        return px.scatter_3d(df, x='x', y='y', z='z', color='red', size=size)

    def generate_fig2(self, image):
        """Return a Matplotlib 3-D scatter of the point cloud with per-point RGB colours."""
        points, colors = self.generate_pcl(image)
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        # Axes3D.scatter takes separate x, y, z arrays; marker size is the `s` keyword.
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=0.01, c=colors, marker='o')
        return fig

    @staticmethod
    def _make_cube(point, color, cube_size):
        """Return a cube of edge ``cube_size`` translated by ``-point`` and painted ``color``."""
        cube = o3d.geometry.TriangleMesh.create_box(width=cube_size, height=cube_size, depth=cube_size)
        # The point is negated, i.e. the cloud is mirrored through the origin.
        cube.translate(-point)
        cube.paint_uniform_color(color)
        return cube

    @staticmethod
    def _save_mesh(mesh):
        """Write ``mesh`` to ``./cloud.obj`` and return that path."""
        output_file = "./cloud.obj"
        o3d.io.write_triangle_mesh(output_file, mesh)
        return output_file

    def generate_obj_rgb(self, image, n_samples, cube_size):
        """Export a random sample of the point cloud as colour-painted cubes.

        Args:
            image: Image passed to :meth:`generate_pcl`.
            n_samples: Number of points to draw (with replacement); cast to ``int``.
            cube_size: Edge length of each cube.

        Returns:
            Path of the written ``.obj`` file (``./cloud.obj``).
        """
        point_cloud, color_array = self.generate_pcl(image)
        idxs = np.random.choice(len(point_cloud), int(n_samples))
        mesh = o3d.geometry.TriangleMesh()
        for point, color in zip(point_cloud[idxs], color_array[idxs]):
            mesh += self._make_cube(point, color, cube_size)
        return self._save_mesh(mesh)

    def generate_obj_masks(self, image, n_samples, masks, cube_size):
        """Export sampled points under each mask as cubes, one solid colour per mask.

        Only the first ``len(_MASK_COLORS)`` (three) masks are used.

        Args:
            image: Image passed to :meth:`generate_pcl`.
            n_samples: Points drawn (with replacement) per mask; cast to ``int``.
            masks: Iterable of ``(mask, _)`` pairs; each ``mask`` is flattened and
                used as a boolean index into the point cloud.
                NOTE(review): assumes ``mask.size`` equals the number of points —
                Open3D may drop invalid-depth points, which would break this.
            cube_size: Edge length of each cube.

        Returns:
            Path of the written ``.obj`` file (``./cloud.obj``).
        """
        point_cloud = self.generate_pcl(image)[0]
        mesh = o3d.geometry.TriangleMesh()
        for color, (mask, _) in zip(self._MASK_COLORS, masks):
            subset = point_cloud[mask.ravel()]
            if len(subset) == 0:
                # Empty mask: nothing to sample, and np.random.choice would raise.
                continue
            idxs = np.random.choice(len(subset), int(n_samples))
            for point in subset[idxs]:
                mesh += self._make_cube(point, color, cube_size)
        return self._save_mesh(mesh)
    



class SegmentPredictor:
    """Segment Anything (SAM) wrapper for automatic and point-prompted segmentation."""

    def __init__(self, model_type="vit_h", checkpoint="sam_vit_h_4b8939.pth"):
        """Load a SAM model and build both predictors on the available device.

        Args:
            model_type: Key into ``sam_model_registry`` (default ``"vit_h"``).
            checkpoint: Weights file matching ``model_type``.
        """
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        sam.to(device=self.device)
        self.mask_generator = SamAutomaticMaskGenerator(sam)
        self.conditioned_pred = SamPredictor(sam)

    def encode(self, image):
        """Set ``image`` as the current image of the point-prompt predictor.

        Call this before :meth:`cond_pred`.
        """
        self.conditioned_pred.set_image(np.array(image))

    def cond_pred(self, pts, lbls, multimask_output=True):
        """Predict masks for point prompts on the image last given to :meth:`encode`.

        Args:
            pts: Point coordinates, converted with ``np.array``.
            lbls: One label per point, converted with ``np.array``.
            multimask_output: Forwarded to ``SamPredictor.predict`` (default ``True``).

        Returns:
            The masks from ``SamPredictor.predict``; its other two outputs are discarded.
        """
        masks, _, _ = self.conditioned_pred.predict(
            point_coords=np.array(pts),
            point_labels=np.array(lbls),
            multimask_output=multimask_output,
        )
        return masks

    def segment_everything(self, image):
        """Run automatic mask generation and return ``image`` with the masks overlaid."""
        image = np.array(image)
        sam_result = self.mask_generator.generate(image)
        detections = sv.Detections.from_sam(sam_result=sam_result)
        return sv.MaskAnnotator().annotate(scene=image.copy(), detections=detections)