import numpy as np
import open3d as o3d
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (keeps the "3d" projection registered on older matplotlib)
from inference import DepthPredictor


def create_3d_obj(rgb_image, depth_image, depth=10, path="./image.gltf"):
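    """Reconstruct a textured mesh from an RGB image and a depth map and export it as glTF.

    The depth map is back-projected with a simple pinhole model (fx = fy = 500, principal
    point at the image center), normals are estimated, Poisson reconstruction is run at the
    given octree ``depth``, and the simplified, cropped mesh is written to ``path``.
    Returns the output path. ``depth_image`` must use a dtype Open3D images accept
    (e.g. uint16 or float32).
    """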
    depth_o3d = o3d.geometry.Image(depth_image)
    image_o3d = o3d.geometry.Image(rgb_image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False
    )
    w = int(depth_image.shape[1])
    h = int(depth_image.shape[0])

    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic()
    camera_intrinsic.set_intrinsics(w, h, 500, 500, w / 2, h / 2)

    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, camera_intrinsic)

    print("normals")
    pcd.normals = o3d.utility.Vector3dVector(
        np.zeros((1, 3))
    )  # invalidate existing normals
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01, max_nn=30)
    )
    pcd.orient_normals_towards_camera_location(
        camera_location=np.array([0.0, 0.0, 1000.0])
    )
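    # The two flips below negate x, y, and z in total, re-orienting the cloud for viewing.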
    pcd.transform([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
    pcd.transform([[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])

    print("run Poisson surface reconstruction")
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
        mesh_raw, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
            pcd, depth=depth, width=0, scale=1.1, linear_fit=True
        )

    voxel_size = max(mesh_raw.get_max_bound() - mesh_raw.get_min_bound()) / 256
    print(f"voxel_size = {voxel_size:e}")
    mesh = mesh_raw.simplify_vertex_clustering(
        voxel_size=voxel_size,
        contraction=o3d.geometry.SimplificationContraction.Average,
    )

    # vertices_to_remove = densities < np.quantile(densities, 0.001)
    # mesh.remove_vertices_by_mask(vertices_to_remove)
    bbox = pcd.get_axis_aligned_bounding_box()
    mesh_crop = mesh.crop(bbox)
    gltf_path = path
    o3d.io.write_triangle_mesh(gltf_path, mesh_crop, write_triangle_uvs=True)
    return gltf_path


def create_3d_pc(rgb_image, depth_image, depth=10):
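    """Build a colored point cloud from an RGB image and a depth map and save it as ``pc.pcd``.

    Uses the same pinhole model as ``create_3d_obj`` (fx = fy = 500, principal point at the
    image center) and estimates oriented normals. Returns the output filename; the ``depth``
    argument is currently unused.
    """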
    depth_image = depth_image.astype(np.float32)  # Convert depth map to float32
    depth_o3d = o3d.geometry.Image(depth_image)
    image_o3d = o3d.geometry.Image(rgb_image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False
    )

    w = int(depth_image.shape[1])
    h = int(depth_image.shape[0])

    # Specify camera intrinsic parameters (modify based on actual camera)
    fx = 500
    fy = 500
    cx = w / 2
    cy = h / 2

    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic(w, h, fx, fy, cx, cy)

    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, camera_intrinsic)

    print("Estimating normals...")
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01, max_nn=30)
    )
    pcd.orient_normals_towards_camera_location(
        camera_location=np.array([0.0, 0.0, 1000.0])
    )

    # Save the point cloud in PCD format
    filename = "pc.pcd"
    o3d.io.write_point_cloud(filename, pcd)

    return filename  # Return the path of the saved point-cloud file


def point_cloud(rgb_image):
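    """Predict a depth map for ``rgb_image`` with DepthPredictor, build an Open3D point
    cloud using the PrimeSense default intrinsics, and return a Plotly 3D scatter figure
    of the points (colored by the red channel).
    """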
    depth_predictor = DepthPredictor()
    depth_result = depth_predictor.predict(rgb_image)
    # Step 2: Create an RGBD image from the RGB image and the predicted depth map
    depth_o3d = o3d.geometry.Image(depth_result)
    image_o3d = o3d.geometry.Image(rgb_image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False
    )
    # Step 3: Create a PointCloud from the RGBD image
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
        rgbd_image,
        o3d.camera.PinholeCameraIntrinsic(
            o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault
        ),
    )
    # Step 4: Convert PointCloud data to a NumPy array
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)
    # Step 5: Create a DataFrame from the NumPy arrays
    data = {
        "x": points[:, 0],
        "y": points[:, 1],
        "z": points[:, 2],
        "red": colors[:, 0],
        "green": colors[:, 1],
        "blue": colors[:, 2],
    }
    df = pd.DataFrame(data)
    size = np.full(len(df), 0.01)
    # Step 6: Create a 3D scatter plot using Plotly Express
    fig = px.scatter_3d(df, x="x", y="y", z="z", color="red", size=size)

    return fig


def array_PCL(rgb_image, depth_image):
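    """Back-project a depth map into an (N, 3) array of 3D points plus matching (N, 3) RGB colors.

    ``rgb_image`` and ``depth_image`` must share the same height and width; the returned
    colors are scaled to [0, 1].
    """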
    # Pinhole intrinsics of the RGB camera (these match the NYU Depth v2 calibration);
    # the depth camera is assumed to share the same intrinsics.
    FX_RGB = 5.1885790117450188e02
    FY_RGB = 5.1946961112127485e02
    CX_RGB = 3.2558244941119034e02
    CY_RGB = 2.5373616633400465e02
    FX_DEPTH = FX_RGB
    FY_DEPTH = FY_RGB
    CX_DEPTH = CX_RGB
    CY_DEPTH = CY_RGB
    height = depth_image.shape[0]
    width = depth_image.shape[1]
    # compute indices:
    jj = np.tile(range(width), height)
    ii = np.repeat(range(height), width)

    # Compute constants:
    xx = (jj - CX_DEPTH) / FX_DEPTH
    yy = (ii - CY_DEPTH) / FY_DEPTH

    # transform depth image to vector of z:
    length = height * width
    z = depth_image.reshape(length)

    # compute point cloud
    pcd = np.dstack((xx * z, yy * z, z)).reshape((length, 3))
    # Since the color and depth cameras share the same intrinsics here, each 3D point maps
    # back to its own pixel, so the colors are simply the flattened RGB values in [0, 1].
    colors = rgb_image.reshape(length, 3) / 255
    return pcd, colors


def generate_PCL(image):
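    """Predict depth for ``image`` and return an uncolored Plotly 3D scatter of its point cloud."""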
    depth_predictor = DepthPredictor()
    depth_result = depth_predictor.predict(image)
    image = np.array(image)
    pcd, _ = array_PCL(image, depth_result)
    fig = px.scatter_3d(x=pcd[:, 0], y=pcd[:, 1], z=pcd[:, 2], size_max=0.01)
    return fig


def plot_PCL(rgb_image, depth_image):
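    """Plot the point cloud from ``array_PCL`` as a Plotly 3D scatter, colored with the image's RGB values."""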
    pcd, colors = array_PCL(rgb_image, depth_image)
    # An (N, 3) float array is not a valid per-point color spec for Plotly, so pass
    # CSS "rgb(...)" strings to a single Scatter3d trace instead.
    color_strings = [f"rgb({int(r*255)},{int(g*255)},{int(b*255)})" for r, g, b in colors]
    fig = go.Figure(
        go.Scatter3d(x=pcd[:, 0], y=pcd[:, 1], z=pcd[:, 2], mode="markers",
                     marker=dict(size=1, color=color_strings))
    )
    return fig


def PCL3(image):
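    """Predict depth for ``image``, build a colored point cloud with the PrimeSense
    default intrinsics, and return a matplotlib 3D scatter figure of the points.
    """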
    depth_predictor = DepthPredictor()
    depth_result = depth_predictor.predict(image)
    image = np.array(image)
    # Step 2: Create an RGBD image from the RGB and depth image
    depth_o3d = o3d.geometry.Image(depth_result)
    image_o3d = o3d.geometry.Image(image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False
    )
    # Step 3: Create a PointCloud from the RGBD image
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
        rgbd_image,
        o3d.camera.PinholeCameraIntrinsic(
            o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault
        ),
    )
    # Step 4: Convert PointCloud data to NumPy arrays
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)
    sizes = np.zeros(colors.shape)
    sizes[:] = 0.01
    colors = [tuple(c) for c in colors]
    # Step 5: Plot the points with matplotlib
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    print("plotting...")
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=colors, s=0.01)
    print("Plot successful")
    # data = {'x': points[:, 0], 'y': points[:, 1], 'z': points[:, 2], 'sizes': sizes[:, 0]}
    # df = pd.DataFrame(data)
    # Step 6: Create a 3D scatter plot using Plotly Express
    # fig = px.scatter_3d(df, x='x', y='y', z='z', color=colors, size="sizes")

    return fig


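# Minimal usage sketch: it assumes a local test image at "examples/room.jpg" (a hypothetical
# path) and that inference.DepthPredictor.predict() returns a NumPy depth map when given a
# PIL image, as the functions above already expect.
if __name__ == "__main__":
    from PIL import Image

    rgb = Image.open("examples/room.jpg").convert("RGB")

    # Interactive Plotly view of the predicted point cloud.
    generate_PCL(rgb).show()

    # Export a glTF mesh from the same prediction.
    depth = DepthPredictor().predict(rgb)
    create_3d_obj(np.array(rgb), depth, depth=10, path="./image.gltf")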