"""Streamlit app that classifies traffic-sign images with a CNN hosted on Hugging Face."""
import os
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
import tensorflow as tf
from huggingface_hub import hf_hub_download, login
from PIL import Image


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def preprocess_data(images, labels):
    """Scale images to float32 in [0, 1] and encode labels as integer category codes.

    Args:
        images: numpy array of images; presumably uint8 pixels in 0-255 (TODO confirm).
        labels: sequence of class labels.

    Returns:
        ``(images, labels)``: the images as float32 divided by 255, and a numpy
        array of pandas category codes (sorted unique labels mapped to 0..k-1).
    """
    images = images.astype('float32') / 255.0
    labels = pd.Series(labels).astype('category').cat.codes.to_numpy()
    return images, labels


@st.cache_resource
def _hf_login(token):
    """Log in to the Hugging Face Hub once per server process.

    Streamlit re-executes this whole script on every user interaction; caching
    keeps ``login`` from being repeated on each rerun.
    """
    login(token=token)


@st.cache_resource
def _load_training_data(repo_id, filename):
    """Download the pickled training set and return preprocessed ``(images, labels)``.

    Cached so the download, unpickling and preprocessing happen once per server
    process instead of on every rerun. The returned arrays are shared between
    sessions and must not be mutated.
    """
    path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
    # SECURITY: pickle.load can execute arbitrary code. Only load files from a
    # repository you control/trust.
    with open(path, 'rb') as file:
        data = pickle.load(file)
    return preprocess_data(data['features'], data['labels'])


@st.cache_resource
def _load_model(repo_id, filename):
    """Download the Keras model file and load it, once per server process."""
    path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
    return tf.keras.models.load_model(path)


# ---------------------------------------------------------------------------
# Authentication and resource loading
# ---------------------------------------------------------------------------

# Hugging Face token, read from the environment (e.g. a deployment secret).
hf_token = os.getenv("HF_API_TOKEN")
if not hf_token:  # treat an empty value the same as a missing one
    st.error("El token de Hugging Face no está configurado. Por favor, configura el secreto 'HF_API_TOKEN'.")
    st.stop()

_hf_login(hf_token)

# Load the training dataset (used only to display sample images).
dataset_repo = "VicGerardoPR/Traffic_sign_dataset"
dataset_file = "train.p"
try:
    train_images, train_labels = _load_training_data(dataset_repo, dataset_file)
except Exception as e:
    st.error(f"Error al cargar el dataset: {e}")
    st.stop()

# Load the trained classifier.
model_repo = "VicGerardoPR/Traffic_sign_model"
model_file = "traffic_sign_classifier.h5"
try:
    model = _load_model(model_repo, model_file)
except Exception as e:
    st.error(f"Error al cargar el modelo: {e}")
    st.stop()

# Mapping from the model's output index to a human-readable traffic-sign name.
classes = {
    0: 'Speed limit (20km/h)',
    1: 'Speed limit (30km/h)',
    2: 'Speed limit (50km/h)',
    3: 'Speed limit (60km/h)',
    4: 'Speed limit (70km/h)',
    5: 'Speed limit (80km/h)',
    6: 'End of speed limit (80km/h)',
    7: 'Speed limit (100km/h)',
    8: 'Speed limit (120km/h)',
    9: 'No passing',
    10: 'No passing for vehicles over 3.5 metric tons',
    11: 'Right-of-way at the next intersection',
    12: 'Priority road',
    13: 'Yield',
    14: 'Stop',
    15: 'No vehicles',
    16: 'Vehicles over 3.5 metric tons prohibited',
    17: 'No entry',
    18: 'General caution',
    19: 'Dangerous curve to the left',
    20: 'Dangerous curve to the right',
    21: 'Double curve',
    22: 'Bumpy road',
    23: 'Slippery road',
    24: 'Road narrows on the right',
    25: 'Road work',
    26: 'Traffic signals',
    27: 'Pedestrians',
    28: 'Children crossing',
    29: 'Bicycles crossing',
    30: 'Beware of ice/snow',
    31: 'Wild animals crossing',
    32: 'End of all speed and passing limits',
    33: 'Turn right ahead',
    34: 'Turn left ahead',
    35: 'Ahead only',
    36: 'Go straight or right',
    37: 'Go straight or left',
    38: 'Keep right',
    39: 'Keep left',
    40: 'Roundabout mandatory',
    41: 'End of no passing',
    42: 'End of no passing by vehicles over 3.5 metric tons',
}


def predict(image):
    """Classify one traffic-sign image with the module-level ``model``.

    Args:
        image: a PIL image of any mode/size, or an array-like image.

    Returns:
        The class name from ``classes`` for the highest-scoring output.
    """
    if isinstance(image, Image.Image):
        # Uploaded PNGs are frequently RGBA, palette ('P') or grayscale ('L'),
        # which would produce 4- or 1-channel arrays. The model presumably
        # expects 3-channel input like the training set (TODO confirm).
        image = image.convert("RGB")
    pixels = cv2.resize(np.array(image), (32, 32))
    # Same scaling as preprocess_data (float32 in [0, 1]), plus a batch axis.
    batch = np.expand_dims(pixels.astype('float32') / 255.0, axis=0)
    predictions = model.predict(batch)
    class_idx = int(np.argmax(predictions, axis=1)[0])
    return classes[class_idx]


# ---------------------------------------------------------------------------
# User interface
# ---------------------------------------------------------------------------

# Application title and description.
st.title("Traffic Sign Classifier")
st.write("Esta aplicación clasifica señales de tráfico usando un modelo de CNN.")

# Show a grid of random examples from the dataset.
st.header("Ejemplos de Imágenes del Conjunto de Datos")
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.ravel()
# Sample without replacement so no image appears twice. The sample is also
# capped so datasets with fewer images than grid cells don't fail.
sample_size = min(len(axes), len(train_images))
sample_idx = np.random.choice(len(train_images), size=sample_size, replace=False)
for ax in axes:
    ax.axis('off')
for ax, idx in zip(axes, sample_idx):
    ax.imshow(train_images[idx])
    ax.set_title(f"Clase: {train_labels[idx]}")
st.pyplot(fig)
# pyplot keeps every figure alive until it is closed; without this each rerun
# would leak one figure in the long-running server process.
plt.close(fig)

# Let the user upload an image to classify.
st.header("Carga tu Propia Imagen de Señal de Tráfico")
uploaded_file = st.file_uploader("Elige una imagen...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        image.load()  # decode now so a corrupt file fails here, not inside predict
    except OSError as e:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error(f"No se pudo leer la imagen: {e}")
        st.stop()
    st.image(image, caption='Imagen Cargada', use_column_width=True)
    st.write("")
    st.write("Clasificando...")
    label = predict(image)
    st.write(f"Esta señal es: {label}")