import os
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
import pickle

# Read the Hugging Face access token from the environment.
hf_token = os.getenv("HF_API_TOKEN")

# Treat an empty value the same as a missing one: login() cannot use "" either.
if not hf_token:
    st.error("El token de Hugging Face no está configurado. Por favor, configura el secreto 'HF_API_TOKEN'.")
    st.stop()

# Authenticate this process so the hf_hub_download calls further down can
# reach the dataset/model repositories (presumably needed if they are private).
from huggingface_hub import login

try:
    login(token=hf_token)
except Exception as e:
    # Report auth failures in the UI, like the dataset/model loaders do,
    # instead of crashing the script with a raw traceback.
    st.error(f"Error al autenticarse con Hugging Face: {e}")
    st.stop()

# Load the training dataset (a pickled dict) from the Hugging Face Hub.
dataset_repo = "VicGerardoPR/Traffic_sign_dataset"
dataset_file = "train.p"

try:
    dataset_path = hf_hub_download(repo_id=dataset_repo, filename=dataset_file, repo_type="dataset")
    # SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
    # the file, so this trusts whoever controls the remote dataset repository.
    # Consider pinning a `revision=` or moving to a non-executable format
    # (e.g. .npz loaded with allow_pickle=False).
    with open(dataset_path, 'rb') as file:
        train_data = pickle.load(file)
    # Extract inside the try so a file lacking these keys is reported in the
    # UI rather than crashing the script with a bare KeyError traceback.
    train_images, train_labels = train_data['features'], train_data['labels']
except Exception as e:
    st.error(f"Error al cargar el dataset: {e}")
    st.stop()

# Data preprocessing
def preprocess_data(images, labels):
    """Scale image pixels and encode labels as contiguous integer codes.

    Args:
        images: numpy array of raw pixel values (expected 0-255).
        labels: sequence of class labels.

    Returns:
        A ``(images, labels)`` tuple: the images as float32 divided by 255,
        and a numpy array of pandas category codes for the labels (each
        label's position among the sorted unique label values).
    """
    label_series = pd.Series(labels)
    categorical_labels = label_series.astype('category')
    encoded_labels = categorical_labels.cat.codes.to_numpy()

    scaled_images = images.astype('float32') / 255.0
    return scaled_images, encoded_labels

# Replace the raw arrays with float32 images scaled to [0, 1] and
# integer-coded labels.
train_images, train_labels = preprocess_data(
    images=train_images,
    labels=train_labels,
)

# Download the trained Keras classifier from the Hugging Face Hub and load it.
model_repo = "VicGerardoPR/Traffic_sign_model"
model_file = "traffic_sign_classifier.h5"

try:
    model_path = hf_hub_download(
        repo_id=model_repo,
        filename=model_file,
        repo_type="model",
    )
    model = tf.keras.models.load_model(model_path)
except Exception as load_error:
    # Any download or deserialization failure is fatal for the app.
    st.error(f"Error al cargar el modelo: {load_error}")
    st.stop()

# Traffic-sign class names keyed by the model's output index (see predict,
# which looks up the argmax position here). Ids follow list order: 0..42.
classes = dict(enumerate([
    'Speed limit (20km/h)',
    'Speed limit (30km/h)',
    'Speed limit (50km/h)',
    'Speed limit (60km/h)',
    'Speed limit (70km/h)',
    'Speed limit (80km/h)',
    'End of speed limit (80km/h)',
    'Speed limit (100km/h)',
    'Speed limit (120km/h)',
    'No passing',
    'No passing for vehicles over 3.5 metric tons',
    'Right-of-way at the next intersection',
    'Priority road',
    'Yield',
    'Stop',
    'No vehicles',
    'Vehicles over 3.5 metric tons prohibited',
    'No entry',
    'General caution',
    'Dangerous curve to the left',
    'Dangerous curve to the right',
    'Double curve',
    'Bumpy road',
    'Slippery road',
    'Road narrows on the right',
    'Road work',
    'Traffic signals',
    'Pedestrians',
    'Children crossing',
    'Bicycles crossing',
    'Beware of ice/snow',
    'Wild animals crossing',
    'End of all speed and passing limits',
    'Turn right ahead',
    'Turn left ahead',
    'Ahead only',
    'Go straight or right',
    'Go straight or left',
    'Keep right',
    'Keep left',
    'Roundabout mandatory',
    'End of no passing',
    'End of no passing by vehicles over 3.5 metric tons',
]))

def predict(image):
    image = np.array(image)
    image = cv2.resize(image, (32, 32))
    image = image / 255.0
    image = np.expand_dims(image, axis=0)

    predictions = model.predict(image)
    class_idx = np.argmax(predictions, axis=1)[0]
    return classes[class_idx]

# Application title and description.
st.title("Traffic Sign Classifier")
st.write("Esta aplicación clasifica señales de tráfico usando un modelo de CNN.")

# Show a 2x5 grid of randomly chosen training images with their label codes.
st.header("Ejemplos de Imágenes del Conjunto de Datos")
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.ravel()
for ax in axes:
    idx = np.random.randint(0, len(train_images))
    ax.imshow(train_images[idx])
    ax.set_title(f"Clase: {train_labels[idx]}")
    ax.axis('off')
st.pyplot(fig)
# Streamlit re-executes this script on every interaction; close the figure so
# pyplot's global figure registry does not grow by one figure per rerun.
plt.close(fig)

# Let the user upload their own traffic-sign image and classify it.
st.header("Carga tu Propia Imagen de Señal de Tráfico")
uploaded_file = st.file_uploader(
    label="Elige una imagen...",
    type=["jpg", "jpeg", "png"],
)

if uploaded_file is not None:
    # Preview the upload, then show the predicted class name below it.
    image = Image.open(uploaded_file)
    st.image(image, caption='Imagen Cargada', use_column_width=True)
    st.write("")  # blank spacer between the preview and the status text
    st.write("Clasificando...")
    label = predict(image)
    st.write(f"Esta señal es: {label}")