import numpy as np
import open3d as o3d
import open3d as o3d
import plotly.express as px
import numpy as np
import pandas as pd
from inference import DepthPredictor

def create_3d_obj(rgb_image, depth_image, depth=10, path='./image.gltf'):
    """Reconstruct a triangle mesh from an RGB-D pair and write it to ``path``.

    Pipeline: back-project the colour/depth pair through a fixed pinhole
    camera (fx = fy = 500, principal point at the image centre), re-estimate
    and orient normals, flip the cloud's axes, fit a Poisson surface, decimate
    it by vertex clustering (256 cells along the longest extent) and crop it to
    the point cloud's axis-aligned bounding box.

    Args:
        rgb_image: colour image, wrapped with ``o3d.geometry.Image``.
        depth_image: depth map, wrapped with ``o3d.geometry.Image``; its first
            two dimensions are taken as (height, width).
        depth: octree depth forwarded to Poisson reconstruction.
        path: destination file for ``o3d.io.write_triangle_mesh``.

    Returns:
        ``path``, the file the mesh was written to.
    """
    depth_img = o3d.geometry.Image(depth_image)
    color_img = o3d.geometry.Image(rgb_image)
    rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
        color_img, depth_img, convert_rgb_to_intensity=False)
    width, height = int(depth_image.shape[1]), int(depth_image.shape[0])

    intrinsic = o3d.camera.PinholeCameraIntrinsic(
        width, height, 500, 500, width / 2, height / 2)
    cloud = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, intrinsic)

    print('normals')
    # Replace any existing normals so estimate_normals starts from scratch.
    cloud.normals = o3d.utility.Vector3dVector(np.zeros((1, 3)))
    cloud.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01, max_nn=30))
    cloud.orient_normals_towards_camera_location(
        camera_location=np.array([0., 0., 1000.]))

    # Flip Y/Z, then mirror X (applied in this order).
    flip_yz = [[1, 0, 0, 0],
               [0, -1, 0, 0],
               [0, 0, -1, 0],
               [0, 0, 0, 1]]
    mirror_x = [[-1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1]]
    for matrix in (flip_yz, mirror_x):
        cloud.transform(matrix)

    print('run Poisson surface reconstruction')
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug):
        raw_mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
            cloud, depth=depth, width=0, scale=1.1, linear_fit=True)

    extent = raw_mesh.get_max_bound() - raw_mesh.get_min_bound()
    voxel_size = max(extent) / 256
    print(f'voxel_size = {voxel_size:e}')
    simplified = raw_mesh.simplify_vertex_clustering(
        voxel_size=voxel_size,
        contraction=o3d.geometry.SimplificationContraction.Average)

    # Poisson extrapolates beyond the data; trim back to the cloud's extent.
    cropped = simplified.crop(cloud.get_axis_aligned_bounding_box())
    o3d.io.write_triangle_mesh(path, cropped, write_triangle_uvs=True)
    return path


def create_3d_pc(rgb_image, depth_image, depth=10, path="pc.pcd"):
    """Build a coloured point cloud from an RGB-D pair and save it to ``path``.

    The pair is back-projected through a fixed pinhole camera (fx = fy = 500,
    principal point at the image centre); normals are estimated and oriented
    towards a point far along +Z before the cloud is written.

    Args:
        rgb_image: colour image, wrapped with ``o3d.geometry.Image``.
        depth_image: depth map; cast to float32 before wrapping. Its first two
            dimensions are taken as (height, width).
        depth: unused; kept so the signature parallels ``create_3d_obj``.
        path: output file. Open3D picks the format from the extension; the
            default ``"pc.pcd"`` writes a PCD file.

    Returns:
        ``path``, the file the point cloud was written to.
    """
    # Open3D images only wrap uint8/uint16/float32 buffers.
    depth_image = depth_image.astype(np.float32)
    depth_o3d = o3d.geometry.Image(depth_image)
    image_o3d = o3d.geometry.Image(rgb_image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False)

    w = int(depth_image.shape[1])
    h = int(depth_image.shape[0])

    # Placeholder intrinsics; replace with the real camera's calibration.
    fx = 500
    fy = 500
    cx = w / 2
    cy = h / 2

    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic(w, h, fx, fy, cx, cy)

    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
        rgbd_image, camera_intrinsic)

    print('Estimating normals...')
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01, max_nn=30))
    pcd.orient_normals_towards_camera_location(
        camera_location=np.array([0., 0., 1000.]))

    o3d.io.write_point_cloud(path, pcd)
    return path


def point_cloud(rgb_image, depth_image):
    """Return a Plotly 3-D scatter of the point cloud back-projected from an RGB-D pair.

    Back-projection uses Open3D's PrimeSense default intrinsics. Points are
    coloured by their red channel only.

    Args:
        rgb_image: colour image, wrapped with ``o3d.geometry.Image``.
        depth_image: depth map, wrapped with ``o3d.geometry.Image``.

    Returns:
        A ``plotly.express.scatter_3d`` figure over columns x/y/z.
    """
    depth_img = o3d.geometry.Image(depth_image)
    color_img = o3d.geometry.Image(rgb_image)
    rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
        color_img, depth_img, convert_rgb_to_intensity=False)
    intrinsic = o3d.camera.PinholeCameraIntrinsic(
        o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault)
    cloud = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, intrinsic)

    xyz = np.asarray(cloud.points)
    rgb = np.asarray(cloud.colors)
    columns = {}
    for idx, axis in enumerate(('x', 'y', 'z')):
        columns[axis] = xyz[:, idx]
    for idx, channel in enumerate(('red', 'green', 'blue')):
        columns[channel] = rgb[:, idx]
    frame = pd.DataFrame(columns)

    return px.scatter_3d(frame, x='x', y='y', z='z', color='red', size_max=0.1)

def array_PCL(rgb_image, depth_image):
    """Back-project a depth map into an ``(H * W, 3)`` array of camera-space points.

    Each pixel (row i, column j) with depth z maps to
    ``((j - cx) / fx * z, (i - cy) / fy * z, z)`` using hard-coded pinhole
    intrinsics. NOTE(review): the constants look like a Kinect RGB calibration;
    confirm they suit the source of ``depth_image``.

    Args:
        rgb_image: unused; kept for interface compatibility (per-point colour
            lookup is not implemented here).
        depth_image: array whose first two dimensions are (height, width) and
            which holds ``height * width`` depth values.

    Returns:
        float array of shape ``(height * width, 3)``. Points are in row-major
        pixel order: row ``k`` corresponds to pixel ``(k // width, k % width)``.
    """
    # The depth and RGB cameras share these intrinsics.
    FX = 5.1885790117450188e+02
    FY = 5.1946961112127485e+02
    CX = 3.2558244941119034e+02  # was e+0 (typo): principal point ~325 px, not ~3 px
    CY = 2.5373616633400465e+02

    depth = np.asarray(depth_image)
    height, width = depth.shape[0], depth.shape[1]

    # Row-major pixel indices: column index cycles fastest.
    jj = np.tile(np.arange(width), height)
    ii = np.repeat(np.arange(height), width)

    xx = (jj - CX) / FX
    yy = (ii - CY) / FY
    z = depth.reshape(height * width)

    return np.stack((xx * z, yy * z, z), axis=-1)

def generate_PCL(image):
    """Predict depth for ``image`` and return a Plotly 3-D scatter of the resulting points.

    A fresh ``DepthPredictor`` is constructed on every call. The image is
    passed to the predictor as given, then converted with ``np.array`` for
    back-projection via ``array_PCL``.

    Args:
        image: input image accepted by ``DepthPredictor.predict``.

    Returns:
        A ``plotly.express.scatter_3d`` figure of the back-projected points.
    """
    predicted_depth = DepthPredictor().predict(image)
    cloud = array_PCL(np.array(image), predicted_depth)
    xs, ys, zs = cloud[:, 0], cloud[:, 1], cloud[:, 2]
    return px.scatter_3d(x=xs, y=ys, z=zs, size_max=0.01)


def plot_PCL(rgb_image, depth_image):
    """Return a Plotly 3-D scatter of the back-projected depth map, coloured by the image.

    ``array_PCL`` returns only the points (shape ``(H * W, 3)``, row-major pixel
    order), so the per-point colours are taken here. A row-major reshape of the
    image lines each colour up with its point.

    Args:
        rgb_image: colour image of shape ``(H, W, C)`` with ``C >= 3``. It must
            have the same H x W as ``depth_image``. Values are assumed to be
            0-255 (NOTE(review): confirm; float 0-1 images would render dark).
        depth_image: depth map accepted by ``array_PCL``.

    Returns:
        A ``plotly.express.scatter_3d`` figure with one colour per point.

    Raises:
        ValueError: if the image is not ``(H, W, C>=3)`` or its pixel count does
            not match the number of points.
    """
    pcd = array_PCL(rgb_image, depth_image)
    rgb = np.asarray(rgb_image)
    if rgb.ndim != 3 or rgb.shape[-1] < 3:
        raise ValueError(f"expected an (H, W, 3) colour image, got shape {rgb.shape}")
    colors = rgb.reshape(-1, rgb.shape[-1])[:, :3]
    if len(colors) != len(pcd):
        raise ValueError(
            f"image has {len(colors)} pixels but the depth map produced {len(pcd)} points")

    fig = px.scatter_3d(x=pcd[:, 0], y=pcd[:, 1], z=pcd[:, 2], size_max=0.1)
    # px's `color=` argument would treat an RGB array as categories or a
    # colour scale; set literal per-marker colours on the single trace instead.
    fig.update_traces(marker_color=[f"rgb({r},{g},{b})" for r, g, b in colors])
    return fig


def PCL3(image):
    """Predict depth for ``image`` and return an Open3D visualizer holding its point cloud.

    Back-projection uses Open3D's PrimeSense default intrinsics.

    Args:
        image: input image accepted by ``DepthPredictor.predict``; also used
            (via ``np.array``) as the colour channel.

    Returns:
        An ``o3d.visualization.Visualizer`` with a hidden window to which the
        point cloud has been added.
    """
    depth_predictor = DepthPredictor()
    depth_result = np.asarray(depth_predictor.predict(image))
    image = np.array(image)
    # Open3D images only wrap uint8/uint16/float32 buffers, so narrow other float
    # dtypes (e.g. float64). Integer depth is left as-is because Open3D scales
    # uint8 depth differently from float depth.
    if np.issubdtype(depth_result.dtype, np.floating) and depth_result.dtype != np.float32:
        depth_result = depth_result.astype(np.float32)

    # Create an RGBD image from the colour image and predicted depth.
    depth_o3d = o3d.geometry.Image(depth_result)
    image_o3d = o3d.geometry.Image(image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False)
    # Back-project it into a point cloud.
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
        rgbd_image,
        o3d.camera.PinholeCameraIntrinsic(
            o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault))

    vis = o3d.visualization.Visualizer()
    # add_geometry silently returns False until a window/GL context exists, so
    # create one first. It is kept hidden so building the visualizer does not
    # pop up a window. NOTE(review): confirm callers do not expect it visible.
    vis.create_window(visible=False)
    vis.add_geometry(pcd)
    return vis