import numpy as np
import open3d as o3d
import open3d as o3d
import plotly.express as px
import numpy as np
import pandas as pd
from inference import DepthPredictor
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def create_3d_obj(rgb_image, depth_image, depth=10, path='./image.gltf'):
    """Reconstruct a triangle mesh from an RGB image and depth map and save it.

    Steps:

    1. Back-project the RGB-D pair into a coloured point cloud. The pinhole
       camera has focal length 500 px on both axes and its principal point at
       the image centre.
    2. Estimate normals and orient them towards ``(0, 0, 1000)``.
    3. Negate all three axes with two flips.
    4. Run Poisson surface reconstruction with ``depth=depth``.
    5. Simplify the mesh by vertex clustering on a grid of 256 cells along
       the mesh's largest extent.
    6. Crop the mesh to the point cloud's axis-aligned bounding box and write
       it to ``path``.

    Args:
        rgb_image: colour image handed to ``o3d.geometry.Image``.
        depth_image: depth map handed to ``o3d.geometry.Image``; its first two
            dimensions are read as (height, width).
        depth: ``depth`` argument of the Poisson reconstruction.
        path: destination file for ``o3d.io.write_triangle_mesh``.

    Returns:
        ``path``, the file the mesh was written to.
    """
    depth_img = o3d.geometry.Image(depth_image)
    color_img = o3d.geometry.Image(rgb_image)
    rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
        color_img, depth_img, convert_rgb_to_intensity=False)
    width = int(depth_image.shape[1])
    height = int(depth_image.shape[0])

    intrinsic = o3d.camera.PinholeCameraIntrinsic()
    intrinsic.set_intrinsics(width, height, 500, 500, width / 2, height / 2)

    cloud = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, intrinsic)

    print('normals')
    # Replace any existing normals so estimate_normals starts from scratch.
    cloud.normals = o3d.utility.Vector3dVector(np.zeros((1, 3)))
    cloud.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01, max_nn=30))
    cloud.orient_normals_towards_camera_location(
        camera_location=np.array([0., 0., 1000.]))

    # Flip Y and Z, then flip X: the net effect negates every axis.
    axis_flips = (
        [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]],
        [[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
    )
    for flip in axis_flips:
        cloud.transform(flip)

    print('run Poisson surface reconstruction')
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug):
        poisson_mesh, _densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
            cloud, depth=depth, width=0, scale=1.1, linear_fit=True)

    extent = poisson_mesh.get_max_bound() - poisson_mesh.get_min_bound()
    voxel_size = max(extent) / 256
    print(f'voxel_size = {voxel_size:e}')
    simplified = poisson_mesh.simplify_vertex_clustering(
        voxel_size=voxel_size,
        contraction=o3d.geometry.SimplificationContraction.Average)

    cropped = simplified.crop(cloud.get_axis_aligned_bounding_box())
    o3d.io.write_triangle_mesh(path, cropped, write_triangle_uvs=True)
    return path


def create_3d_pc(rgb_image, depth_image, depth=10, path="pc.pcd"):
    """Back-project an RGB image and depth map into a point cloud and save it.

    The depth map is converted to float32 and paired with the colour image.
    The pair is back-projected with a pinhole camera (fx = fy = 500 px,
    principal point at the image centre). Normals are estimated and oriented
    towards ``(0, 0, 1000)``, and the cloud is written to ``path``.

    Args:
        rgb_image: colour image handed to ``o3d.geometry.Image``.
        depth_image: depth map with an ``astype`` method (e.g. a NumPy array);
            its first two dimensions are read as (height, width).
        depth: accepted but not used by this function.
        path: destination file for ``o3d.io.write_point_cloud``. Open3D picks
            the output format from the extension, so the default ``"pc.pcd"``
            writes a PCD file.

    Returns:
        ``path``, the file the point cloud was written to.
    """
    depth_image = depth_image.astype(np.float32)  # Convert depth map to float32
    depth_o3d = o3d.geometry.Image(depth_image)
    image_o3d = o3d.geometry.Image(rgb_image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False)

    w = int(depth_image.shape[1])
    h = int(depth_image.shape[0])

    # Camera intrinsic parameters (placeholders; modify based on actual camera).
    fx = 500
    fy = 500
    cx = w / 2
    cy = h / 2

    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic(w, h, fx, fy, cx, cy)

    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
        rgbd_image, camera_intrinsic)

    print('Estimating normals...')
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01, max_nn=30))
    pcd.orient_normals_towards_camera_location(
        camera_location=np.array([0., 0., 1000.]))

    # Save the point cloud; the file format follows the extension of `path`.
    o3d.io.write_point_cloud(path, pcd)

    return path


def point_cloud(rgb_image):
    """Predict depth for an RGB image and plot the resulting point cloud.

    Depth comes from ``DepthPredictor().predict``. The RGB/depth pair is
    back-projected with Open3D's PrimeSense default intrinsics, and the points
    are returned as a Plotly Express 3D scatter coloured by the red channel.

    Args:
        rgb_image: input image. It is passed as-is to the depth predictor and
            converted with ``np.array`` before being wrapped as an Open3D image.

    Returns:
        A Plotly figure produced by ``px.scatter_3d``.
    """
    depth_predictor = DepthPredictor()
    depth_result = depth_predictor.predict(rgb_image)
    # o3d.geometry.Image wraps an array buffer; convert the input the same way
    # PCL3 does so that non-array images (e.g. PIL) also work.
    rgb_array = np.array(rgb_image)
    # Step 2: Create an RGBD image from the RGB image and the predicted depth
    # (the original referenced an undefined `depth_image` here -> NameError).
    depth_o3d = o3d.geometry.Image(depth_result)
    image_o3d = o3d.geometry.Image(rgb_array)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False)
    # Step 3: Create a PointCloud from the RGBD image
    intrinsic = o3d.camera.PinholeCameraIntrinsic(
        o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault)
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic)
    # Step 4: Convert PointCloud data to NumPy arrays
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)
    # Step 5: Create a DataFrame from the NumPy arrays
    df = pd.DataFrame({'x': points[:, 0], 'y': points[:, 1], 'z': points[:, 2],
                       'red': colors[:, 0], 'green': colors[:, 1], 'blue': colors[:, 2]})
    size = np.full(len(df), 0.01)
    # Step 6: Create a 3D scatter plot using Plotly Express
    fig = px.scatter_3d(df, x='x', y='y', z='z', color='red', size=size)
    return fig

def array_PCL(rgb_image, depth_image):
    """Back-project a depth map into an (H*W, 3) array of camera-space points.

    Pixel (row i, column j) with depth z maps to
    ``((j - cx) * z / fx, (i - cy) * z / fy, z)``. Points are in row-major
    pixel order. The same fixed intrinsics are used for the RGB and depth
    cameras.

    Args:
        rgb_image: accepted for interface compatibility; it does not affect the
            returned points.
        depth_image: depth map whose first two dimensions are (height, width)
            and which holds exactly height * width values.

    Returns:
        A NumPy array of shape (height * width, 3) with x, y, z columns.
    """
    # NOTE(review): these look like the NYU Depth V2 Kinect RGB intrinsics
    # (pixels) — confirm they match the images fed in.
    FX_RGB = 5.1885790117450188e+02
    FY_RGB = 5.1946961112127485e+02
    # Was 3.2558...e+0: a dropped exponent digit put cx at ~3 px instead of
    # ~326 px (compare CY_RGB ~254 px).
    CX_RGB = 3.2558244941119034e+02
    CY_RGB = 2.5373616633400465e+02
    FX_DEPTH = FX_RGB
    FY_DEPTH = FY_RGB
    CX_DEPTH = CX_RGB
    CY_DEPTH = CY_RGB

    depth = np.asarray(depth_image)
    height = depth.shape[0]
    width = depth.shape[1]
    length = height * width

    # Column index j and row index i for every pixel, row-major.
    jj = np.tile(np.arange(width), height)
    ii = np.repeat(np.arange(height), width)

    # Normalised image-plane coordinates.
    xx = (jj - CX_DEPTH) / FX_DEPTH
    yy = (ii - CY_DEPTH) / FY_DEPTH

    z = depth.reshape(length)
    return np.stack((xx * z, yy * z, z), axis=1)

def generate_PCL(image):
    """Predict depth for ``image`` and return a Plotly 3D scatter of its points.

    The predicted depth is back-projected with ``array_PCL`` (which also
    receives the image as a NumPy array). The x/y/z columns are then plotted
    with ``px.scatter_3d``.
    """
    predictor = DepthPredictor()
    predicted_depth = predictor.predict(image)
    points = array_PCL(np.array(image), predicted_depth)
    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
    # NOTE(review): size_max only matters when a size= column is supplied;
    # none is given here — confirm intent.
    return px.scatter_3d(x=xs, y=ys, z=zs, size_max=0.01)


def plot_PCL(rgb_image, depth_image):
    """Plot the depth map's point cloud as a Plotly 3D scatter coloured by the RGB image.

    ``array_PCL`` returns only the points, so the previous
    ``pcd, colors = array_PCL(...)`` unpack raised ``ValueError`` for any real
    image. Per-point colours are therefore read straight from ``rgb_image``.
    Pixel k of the image colours point k, since ``array_PCL`` emits points in
    row-major pixel order.

    Args:
        rgb_image: (H, W, C>=3) image whose first three channels colour the
            points; assumed pixel-aligned with ``depth_image`` and in 0-255 —
            TODO confirm for float inputs.
        depth_image: depth map handed to ``array_PCL``.

    Returns:
        A Plotly figure with one Scatter3d trace.

    Raises:
        ValueError: if ``rgb_image`` is not an (H, W, C>=3) image with one
            pixel per depth value.
    """
    pcd = array_PCL(rgb_image, depth_image)
    rgb = np.asarray(rgb_image)
    if rgb.ndim != 3 or rgb.shape[2] < 3 or rgb.shape[0] * rgb.shape[1] != len(pcd):
        raise ValueError(
            "rgb_image must be an (H, W, C>=3) image with one pixel per depth value")
    colors = rgb.reshape(len(pcd), -1)[:, :3].astype(int)
    marker_colors = [f"rgb({r}, {g}, {b})" for r, g, b in colors]
    fig = px.scatter_3d(x=pcd[:, 0], y=pcd[:, 1], z=pcd[:, 2], size_max=0.1)
    # A (N, 3) array is not a valid px `color=` argument; set explicit
    # per-marker colours on the single trace instead.
    fig.update_traces(marker=dict(color=marker_colors))
    return fig


def PCL3(image):
    """Predict depth for ``image`` and return a Matplotlib 3D scatter of the coloured point cloud.

    The image and predicted depth are back-projected with Open3D's PrimeSense
    default intrinsics. Each point is drawn with its RGB colour at marker
    size 0.01.

    Args:
        image: input image. It is passed as-is to the depth predictor and
            converted with ``np.array`` for Open3D.

    Returns:
        The Matplotlib figure containing the 3D axes.
    """
    depth_predictor = DepthPredictor()
    depth_result = depth_predictor.predict(image)
    image = np.array(image)
    # Step 2: Create an RGBD image from the RGB and depth image
    depth_o3d = o3d.geometry.Image(depth_result)
    image_o3d = o3d.geometry.Image(image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False)
    # Step 3: Create a PointCloud from the RGBD image
    intrinsic = o3d.camera.PinholeCameraIntrinsic(
        o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault)
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic)
    # Step 4: Convert PointCloud data to NumPy arrays. (The former Open3D
    # Visualizer was never given a window or run, so it is dropped.)
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)  # (N, 3) rows are read by scatter as RGB
    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure in current
    # Matplotlib, which left the returned figure empty; add_subplot registers
    # the 3D axes on the figure.
    ax = fig.add_subplot(111, projection='3d')
    print("plotting...")
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=colors, s=0.01)
    print("Plot Succesful")
    return fig

import numpy as np