"""Streamlit app that classifies traffic-sign images with a pretrained model.

On every run the script downloads a pickled training set and a model file from
the Hugging Face Hub, shows ten random training images, and classifies an image
uploaded by the user.
"""
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import torch
from transformers import AutoTokenizer, AutoModelForImageClassification
from huggingface_hub import hf_hub_download
from PIL import Image
import pickle

# NOTE(review): Streamlit re-executes this whole script on every user
# interaction, so the dataset and model below are re-read each time. If the
# installed Streamlit version supports them, st.cache_data / st.cache_resource
# around the loading code would avoid that.

# Load the dataset
dataset_repo = "VicGerardoPR/Traffic_sign_dataset"
dataset_file = "train.p"
dataset_path = hf_hub_download(repo_id=dataset_repo, filename=dataset_file)
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load this from a repository you trust.
with open(dataset_path, 'rb') as file:
    train_data = pickle.load(file)
train_images, train_labels = train_data['features'], train_data['labels']


# Preprocess the data
def preprocess_data(images, labels):
    """Scale images to [0, 1] and re-encode labels as integer category codes.

    Args:
        images: numpy array of pixel values (presumably uint8 in [0, 255] --
            TODO confirm against the dataset).
        labels: sequence of class labels.

    Returns:
        Tuple ``(images, labels)``: the images cast to float32 and divided by
        255, and the labels replaced by their pandas categorical codes (the
        sorted unique labels numbered from 0).
    """
    images = images.astype('float32') / 255.0
    labels = pd.Series(labels).astype('category').cat.codes.to_numpy()
    return images, labels


train_images, train_labels = preprocess_data(train_images, train_labels)

# Load the model
model_repo = "VicGerardoPR/Traffic_sign_model"
model_file = "traffic_sign_classifier.h5"
model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
# NOTE(review): ".h5" is normally a Keras/HDF5 extension, but the file is read
# with torch.load, which only understands files written by torch.save -- confirm
# the artifact format. torch.load also unpickles the file (same trust caveat as
# the dataset), and PyTorch 2.6+ defaults to weights_only=True, which rejects a
# fully pickled nn.Module.
model = torch.load(model_path, map_location=torch.device('cpu'))
model.eval()

# Traffic-sign class names, indexed by the model's output position
classes = {
    0: 'Speed limit (20km/h)',
    1: 'Speed limit (30km/h)',
    2: 'Speed limit (50km/h)',
    3: 'Speed limit (60km/h)',
    4: 'Speed limit (70km/h)',
    5: 'Speed limit (80km/h)',
    6: 'End of speed limit (80km/h)',
    7: 'Speed limit (100km/h)',
    8: 'Speed limit (120km/h)',
    9: 'No passing',
    10: 'No passing for vehicles over 3.5 metric tons',
    11: 'Right-of-way at the next intersection',
    12: 'Priority road',
    13: 'Yield',
    14: 'Stop',
    15: 'No vehicles',
    16: 'Vehicles over 3.5 metric tons prohibited',
    17: 'No entry',
    18: 'General caution',
    19: 'Dangerous curve to the left',
    20: 'Dangerous curve to the right',
    21: 'Double curve',
    22: 'Bumpy road',
    23: 'Slippery road',
    24: 'Road narrows on the right',
    25: 'Road work',
    26: 'Traffic signals',
    27: 'Pedestrians',
    28: 'Children crossing',
    29: 'Bicycles crossing',
    30: 'Beware of ice/snow',
    31: 'Wild animals crossing',
    32: 'End of all speed and passing limits',
    33: 'Turn right ahead',
    34: 'Turn left ahead',
    35: 'Ahead only',
    36: 'Go straight or right',
    37: 'Go straight or left',
    38: 'Keep right',
    39: 'Keep left',
    40: 'Roundabout mandatory',
    41: 'End of no passing',
    42: 'End of no passing by vehicles over 3.5 metric tons',
}


def predict(image):
    """Classify a single traffic-sign image with the loaded model.

    Args:
        image: a PIL image, or anything ``np.array`` turns into an
            (H, W, 3) array. PIL images are converted to RGB first, so RGBA,
            palette and grayscale uploads are accepted.

    Returns:
        The sign name from ``classes`` for the arg-max of the model output, or
        a "Clase desconocida (<idx>)" label if the index is not in ``classes``.
    """
    # Uploaded PNGs are often RGBA or palette images and some JPEGs are
    # grayscale; the channel transpose below needs exactly three channels.
    if isinstance(image, Image.Image):
        image = image.convert('RGB')
    image = np.array(image)
    image = cv2.resize(image, (32, 32))
    image = image / 255.0
    image = np.expand_dims(image, axis=0)
    # (N, H, W, C) -> (N, C, H, W), the layout PyTorch image models take.
    image = np.transpose(image, (0, 3, 1, 2))
    image = torch.tensor(image).float()
    with torch.no_grad():
        predictions = model(image)
    class_idx = predictions.argmax().item()
    # .get avoids crashing the app if the model emits an unmapped index.
    return classes.get(class_idx, f"Clase desconocida ({class_idx})")


# App title and description
st.title("Traffic Sign Classifier")
st.write("Esta aplicación clasifica señales de tráfico usando un modelo de CNN.")

# Show sample images from the dataset
st.header("Ejemplos de Imágenes del Conjunto de Datos")
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for ax in axes.ravel():
    idx = np.random.randint(0, len(train_images))
    ax.imshow(train_images[idx])
    ax.set_title(f"Clase: {train_labels[idx]}")
    ax.axis('off')
st.pyplot(fig)
# Streamlit re-runs the script on every interaction; close the figure so
# pyplot does not accumulate one open figure per rerun.
plt.close(fig)

# Let the user upload their own image
st.header("Carga tu Propia Imagen de Señal de Tráfico")
uploaded_file = st.file_uploader("Elige una imagen...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Imagen Cargada', use_column_width=True)
    st.write("")
    st.write("Clasificando...")
    label = predict(image)
    st.write(f"Esta señal es: {label}")