import json
import os

import gradio as gr
import numpy as np
import joblib
import tensorflow as tf
import pandas as pd
from joblib import load
from tensorflow.keras.models import load_model
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import sklearn

# Report the installed versions of the main libraries at startup; this helps
# diagnose environment mismatches with the saved scaler and models.
for _lib_label, _lib_module in (
    ("Gradio", gr),
    ("NumPy", np),
    ("Scikit-learn", sklearn),
    ("Joblib", joblib),
    ("TensorFlow", tf),
    ("Pandas", pd),
):
    print(f"{_lib_label} version: {_lib_module.__version__}")

# Directory paths for the saved models
script_dir = os.path.dirname(os.path.abspath(__file__))
scaler_path = os.path.join(script_dir, 'toolkit', 'scaler_X.json')
rf_model_path = os.path.join(script_dir, 'toolkit', 'rf_model.joblib')
mlp_model_path = os.path.join(script_dir, 'toolkit', 'mlp_model.keras')
meta_model_path = os.path.join(script_dir, 'toolkit', 'meta_model.joblib')
image_path = os.path.join(script_dir, 'toolkit', 'car.png')

# Rebuild the fitted MinMaxScaler from its JSON-serialised attributes and load
# the three trained models. Failures are only reported: the app still starts,
# but any name not assigned here stays unbound, so later predictions fail.
try:
    with open(scaler_path, 'r') as scaler_file:
        scaler_params = json.load(scaler_file)

    # Restore the fitted state directly instead of calling fit(); the
    # array-valued attributes are turned from JSON lists back into NumPy arrays.
    scaler_X = MinMaxScaler()
    for attr_name in ("scale_", "min_", "data_min_", "data_max_", "data_range_"):
        setattr(scaler_X, attr_name, np.array(scaler_params[attr_name]))
    scaler_X.n_features_in_ = scaler_params["n_features_in_"]
    scaler_X.feature_names_in_ = np.array(scaler_params["feature_names_in_"])

    # Base learners first, then the meta model that combines their outputs.
    loaded_rf_model = load(rf_model_path)
    print("Random Forest model loaded successfully.")
    loaded_mlp_model = load_model(mlp_model_path)
    print("MLP model loaded successfully.")
    loaded_meta_model = load(meta_model_path)
    print("Meta model loaded successfully.")
except Exception as e:
    print(f"Error loading models or scaler: {e}")

# Contamination level at which a LiDAR is considered due for cleaning.
_CONTAMINATION_THRESHOLD = 0.4

# Display names of the six LiDARs, in the column order of the model outputs
# (order assumed to match the meta model's columns — TODO confirm).
_LIDAR_NAMES = ['F/L', 'F/R', 'Left', 'Right', 'Roof', 'Rear']

# Simulation time grid: 0 to 3600 seconds (one hour) in 60-second steps.
_SIMULATION_TIMES = np.arange(0, 3601, 60)


def _predict_contamination_and_gradients(example_data_scaled):
    """Run the stacked ensemble and return ``(contamination, gradients)``.

    The MLP's two outputs and the Random Forest's output are concatenated
    column-wise and fed to the meta model. The first 6 meta-model columns are
    returned as contamination levels, the remaining columns as gradients.
    """
    # NOTE(review): assumes the MLP returns exactly two 2-D arrays and the RF
    # returns a 2-D array, so that they can be concatenated along axis 1.
    mlp_contamination, mlp_gradients = loaded_mlp_model.predict(example_data_scaled)
    rf_predictions = loaded_rf_model.predict(example_data_scaled)
    combined_features = np.concatenate(
        [mlp_contamination, mlp_gradients, rf_predictions], axis=1
    )
    meta_predictions = loaded_meta_model.predict(combined_features)
    return meta_predictions[:, :6], meta_predictions[:, 6:]


def _simulate_contamination(initial_levels, time_intervals):
    """Return a ``(len(time_intervals), n_lidars)`` array of simulated levels.

    Each LiDAR's level ramps linearly from its predicted value at the first
    time stamp to twice that value at the last time stamp.
    """
    initial_levels = np.asarray(initial_levels, dtype=float)
    return np.linspace(initial_levels, initial_levels * 2, len(time_intervals), axis=0)


def _calculate_cleaning_time(time_intervals, contamination_levels,
                             threshold=_CONTAMINATION_THRESHOLD):
    """Return, per column, the first time the contamination reaches *threshold*.

    ``contamination_levels`` has one row per entry of ``time_intervals`` and
    one column per LiDAR. The crossing time is linearly interpolated between
    the two samples that bracket the threshold. A curve that already starts at
    or above the threshold yields the first time stamp (cleaning is due now);
    a curve that never reaches it yields the last time stamp.
    """
    cleaning_times = []
    for levels in contamination_levels.T:
        if levels[0] >= threshold:
            # Already past the threshold: without this check the crossing
            # search below finds nothing and reports the *latest* time.
            cleaning_times.append(time_intervals[0])
            continue
        for j in range(1, len(levels)):
            c1, c2 = levels[j - 1], levels[j]
            # First upward crossing; c1 < threshold here, so c2 - c1 > 0.
            if c1 <= threshold <= c2:
                t1, t2 = time_intervals[j - 1], time_intervals[j]
                cleaning_times.append(t1 + (threshold - c1) * (t2 - t1) / (c2 - c1))
                break
        else:
            cleaning_times.append(time_intervals[-1])  # Threshold never reached
    return cleaning_times


def _plot_contamination(time_intervals, simulated_levels, cleaning_times,
                        threshold=_CONTAMINATION_THRESHOLD):
    """Build the contamination-vs-time figure with cleaning-time markers.

    A bare ``matplotlib.figure.Figure`` is used instead of ``pyplot`` so the
    figure is not registered in pyplot's global figure list, which would
    otherwise grow by one (never closed) figure on every request.
    """
    fig = Figure(figsize=(12, 8))
    ax = fig.subplots()

    for name, levels, cleaning_time in zip(_LIDAR_NAMES, simulated_levels.T, cleaning_times):
        ax.plot(time_intervals, levels, label=name)
        ax.scatter(cleaning_time, threshold, color='k')  # Mark the cleaning time point
    # A single threshold line shared by all curves.
    ax.axhline(y=threshold, color='r', linestyle='--', label='Contamination Threshold')

    ax.set_title('Contamination Levels Over Time for Each Lidar')
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Contamination Level')
    ax.legend()
    ax.grid(True)
    return fig


def predict_and_plot(velocity, temperature, precipitation, humidity):
    """Predict per-LiDAR contamination, gradients and cleaning times.

    Args:
        velocity: Vehicle speed (mph, as labelled in the UI).
        temperature: Ambient temperature (°C, as labelled in the UI).
        precipitation: Precipitation (inch, as labelled in the UI).
        humidity: Relative humidity (%, as labelled in the UI).

    Returns:
        A list of 19 values for the Gradio outputs: the figure, then 6
        contamination strings (percent), 6 gradient strings and 6 cleaning-time
        strings (seconds). On any error (including models that failed to load
        at startup) an empty figure and 18 ``"Error"`` strings are returned.
    """
    try:
        # Column names presumably must match the scaler's feature_names_in_.
        example_data = pd.DataFrame({
            'Velocity(mph)': [velocity],
            'Temperature': [temperature],
            'Precipitation': [precipitation],
            'Humidity': [humidity]
        })
        example_data_scaled = scaler_X.transform(example_data)

        contamination_levels, gradients = _predict_contamination_and_gradients(example_data_scaled)
        simulated_levels = _simulate_contamination(contamination_levels[0], _SIMULATION_TIMES)
        cleaning_times = _calculate_cleaning_time(_SIMULATION_TIMES, simulated_levels)
        fig = _plot_contamination(_SIMULATION_TIMES, simulated_levels, cleaning_times)

        contamination_output = [f"{val * 100:.2f}%" for val in contamination_levels[0]]
        gradients_output = [f"{val:.4f}" for val in gradients[0]]
        cleaning_time_output = [f"{val:.2f}" for val in cleaning_times]

        return [fig] + contamination_output + gradients_output + cleaning_time_output

    except Exception as e:
        # Broad catch keeps the UI responsive; the error is reported on stdout.
        print(f"Error in Gradio interface: {e}")
        return [Figure()] + ["Error"] * (3 * len(_LIDAR_NAMES))

# Slider inputs, in the positional order predict_and_plot expects:
# velocity, temperature, precipitation, humidity.
inputs = [
    gr.Slider(label="Velocity (mph)", minimum=0, maximum=100, value=50, step=0.05),
    gr.Slider(label="Temperature (°C)", minimum=-2, maximum=30, value=0, step=0.5),
    gr.Slider(label="Precipitation (inch)", minimum=0, maximum=1, value=0, step=0.01),
    gr.Slider(label="Humidity (%)", minimum=0, maximum=100, value=50),
]

# Human-readable LiDAR names used to label the per-LiDAR text outputs; the
# order follows the lidar order used by predict_and_plot.
_output_lidar_labels = ("Front Left", "Front Right", "Left", "Right", "Roof", "Rear")

contamination_outputs = [
    gr.Textbox(label=f"{lidar} Contamination") for lidar in _output_lidar_labels
]
gradients_outputs = [
    gr.Textbox(label=f"{lidar} Gradient") for lidar in _output_lidar_labels
]
cleaning_time_outputs = [
    gr.Textbox(label=f"{lidar} Cleaning Time") for lidar in _output_lidar_labels
]

# Assemble the page: inputs + LiDAR layout image on top, three columns of text
# outputs below, and the contamination plot at the bottom.
with gr.Blocks(css=".column-container {height: 100%; display: flex; flex-direction: column; justify-content: space-between;}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Environmental Factor-Based Contamination, Gradient, & Cleaning Time Prediction</h1>")
    gr.Markdown("This application predicts the contamination levels, gradients, and cleaning times for different parts of a car's LiDAR system based on environmental factors such as velocity, temperature, precipitation, and humidity.")

    # Top Section: Inputs and Car Image
    with gr.Row():
        with gr.Column(scale=2, elem_classes="column-container"):
            gr.Markdown("### Input Parameters")
            for inp in inputs:
                inp.render()
            submit_button = gr.Button(value="Submit", variant="primary")
            clear_button = gr.Button(value="Clear")

        with gr.Column(scale=1):
            gr.Markdown("### Location of LiDARs")
            gr.Image(image_path)

    # Bottom Section: Outputs (Three columns)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Contamination Predictions ± 7.1%")
            for out in contamination_outputs:
                out.render()

        with gr.Column(scale=2):
            gr.Markdown("### Gradient Predictions")
            for out in gradients_outputs:
                out.render()

        with gr.Column(scale=2):
            gr.Markdown("### Cleaning Time (s) Predictions")
            for out in cleaning_time_outputs:
                out.render()

    # Graph below the outputs
    with gr.Row():
        plot_output = gr.Plot(label="Contamination Levels Over Time")

    # Output components, in the exact order of predict_and_plot's return list.
    all_outputs = [plot_output] + contamination_outputs + gradients_outputs + cleaning_time_outputs

    submit_button.click(
        fn=predict_and_plot,
        inputs=inputs,
        outputs=all_outputs,
    )
    # Clear the plot and every text output. A handler registered without
    # `outputs` has nothing to update, which left the button inert.
    clear_button.click(fn=lambda: [None] * len(all_outputs), outputs=all_outputs)

demo.launch()