|
import os |
|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import cv2 |
|
import tensorflow as tf |
|
from huggingface_hub import hf_hub_download |
|
from PIL import Image |
|
import pickle |
|
|
|
|
|
# Read the Hugging Face access token from the environment (e.g. a Space secret).
hf_token = os.getenv("HF_API_TOKEN")

# os.getenv returns None when the variable is unset and "" when it is set but
# empty; neither can be used to authenticate, so both halt the app here.
if not hf_token:
    st.error("El token de Hugging Face no está configurado. Por favor, configura el secreto 'HF_API_TOKEN'.")
    st.stop()
|
|
|
|
|
# Authenticate this process with the Hugging Face Hub so the downloads below
# can use the token.
from huggingface_hub import login

try:
    login(token=hf_token)
except Exception as e:
    # Top-level boundary: report the failure in the page, in the same way the
    # dataset and model loading steps do, instead of showing a raw traceback.
    st.error(f"Error al autenticar con Hugging Face: {e}")
    st.stop()
|
|
|
|
|
# Location of the pickled training split on the Hugging Face Hub.
dataset_repo = "VicGerardoPR/Traffic_sign_dataset"
dataset_file = "train.p"

try:
    dataset_path = hf_hub_download(repo_id=dataset_repo, filename=dataset_file, repo_type="dataset")

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only keep this while the dataset repository is fully trusted; consider
    # publishing the data in a non-executable format (e.g. .npz) instead.
    with open(dataset_path, 'rb') as file:
        train_data = pickle.load(file)

    # Unpack inside the try so a payload missing these keys is reported through
    # the same error path rather than surfacing as an unhandled KeyError.
    train_images, train_labels = train_data['features'], train_data['labels']
except Exception as e:
    st.error(f"Error al cargar el dataset: {e}")
    st.stop()
|
|
|
|
|
def preprocess_data(images, labels):
    """Rescale images and re-encode labels as integer category codes.

    Args:
        images: Array with an ``astype`` method (e.g. a NumPy array) holding
            pixel data.
        labels: Sequence of label values accepted by ``pd.Series``.

    Returns:
        A ``(images, labels)`` tuple: the images cast to float32 and divided
        by 255.0, and a NumPy array of pandas category codes for the labels.
    """
    scaled = images.astype('float32') / 255.0

    # Category codes are the rank of each label among the sorted unique labels,
    # so they equal the original ids only when those ids are already 0..k-1
    # with none missing.
    as_category = pd.Series(labels).astype('category')
    encoded = as_category.cat.codes.to_numpy()

    return scaled, encoded
|
|
|
# Replace the raw arrays with their preprocessed versions (float32 images
# divided by 255.0, labels as category codes) for the example gallery below.
train_images, train_labels = preprocess_data(
    train_images,
    train_labels,
)
|
|
|
|
|
# Location of the trained Keras classifier on the Hugging Face Hub.
model_repo = "VicGerardoPR/Traffic_sign_model"
model_file = "traffic_sign_classifier.h5"

try:
    model_path = hf_hub_download(repo_id=model_repo, filename=model_file, repo_type="model")

    # The app only calls model.predict, so skip restoring the optimizer/loss
    # state: compile=False avoids load failures when the saved training
    # configuration does not match the installed TensorFlow/Keras version.
    model = tf.keras.models.load_model(model_path, compile=False)
except Exception as e:
    st.error(f"Error al cargar el modelo: {e}")
    st.stop()
|
|
|
|
|
# Human-readable sign names keyed by integer class id (0..42). The position of
# each name in the tuple is its class id, so the order must not change.
classes = dict(enumerate((
    'Speed limit (20km/h)',  # 0
    'Speed limit (30km/h)',  # 1
    'Speed limit (50km/h)',  # 2
    'Speed limit (60km/h)',  # 3
    'Speed limit (70km/h)',  # 4
    'Speed limit (80km/h)',  # 5
    'End of speed limit (80km/h)',  # 6
    'Speed limit (100km/h)',  # 7
    'Speed limit (120km/h)',  # 8
    'No passing',  # 9
    'No passing for vehicles over 3.5 metric tons',  # 10
    'Right-of-way at the next intersection',  # 11
    'Priority road',  # 12
    'Yield',  # 13
    'Stop',  # 14
    'No vehicles',  # 15
    'Vehicles over 3.5 metric tons prohibited',  # 16
    'No entry',  # 17
    'General caution',  # 18
    'Dangerous curve to the left',  # 19
    'Dangerous curve to the right',  # 20
    'Double curve',  # 21
    'Bumpy road',  # 22
    'Slippery road',  # 23
    'Road narrows on the right',  # 24
    'Road work',  # 25
    'Traffic signals',  # 26
    'Pedestrians',  # 27
    'Children crossing',  # 28
    'Bicycles crossing',  # 29
    'Beware of ice/snow',  # 30
    'Wild animals crossing',  # 31
    'End of all speed and passing limits',  # 32
    'Turn right ahead',  # 33
    'Turn left ahead',  # 34
    'Ahead only',  # 35
    'Go straight or right',  # 36
    'Go straight or left',  # 37
    'Keep right',  # 38
    'Keep left',  # 39
    'Roundabout mandatory',  # 40
    'End of no passing',  # 41
    'End of no passing by vehicles over 3.5 metric tons',  # 42
)))
|
|
|
def predict(image):
    """Classify a traffic-sign image and return its human-readable class name.

    Args:
        image: A PIL image or an array-like of pixels. Any channel layout is
            accepted (grayscale, palette, RGB, RGBA); it is normalised to
            three channels before being passed to the model.

    Returns:
        The entry of ``classes`` for the class with the highest model score.
    """
    # PIL images expose convert(); going through "RGB" expands palette and
    # grayscale images and drops the alpha channel of RGBA PNGs.
    if hasattr(image, "convert"):
        image = image.convert("RGB")

    pixels = np.asarray(image)

    # Same normalisation for plain arrays.
    # NOTE(review): assumes the model takes 3-channel input, as the RGB-only
    # path of the previous implementation implied — confirm against the model.
    if pixels.ndim == 2:
        pixels = pixels[..., np.newaxis]
    if pixels.shape[-1] == 1:
        pixels = np.repeat(pixels, 3, axis=-1)
    elif pixels.shape[-1] == 4:
        pixels = pixels[..., :3]

    # Slicing/repeating can leave a non-contiguous view; hand OpenCV a
    # contiguous buffer.
    pixels = cv2.resize(np.ascontiguousarray(pixels), (32, 32))

    # Divide by 255.0 in float32 (same scaling and dtype as preprocess_data)
    # and add the leading batch dimension.
    batch = np.expand_dims(pixels.astype('float32') / 255.0, axis=0)

    predictions = model.predict(batch)
    class_idx = int(np.argmax(predictions, axis=1)[0])
    return classes[class_idx]
|
|
|
|
|
# Page heading and short description.
st.title("Traffic Sign Classifier")
st.write("Esta aplicación clasifica señales de tráfico usando un modelo de CNN.")

# Gallery of randomly chosen training images, each titled with its label code.
st.header("Ejemplos de Imágenes del Conjunto de Datos")
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

# One random dataset index per subplot (2 x 5 = 10 samples, drawn with
# replacement).
sample_indices = np.random.randint(0, len(train_images), size=axes.size)
for ax, idx in zip(axes.ravel(), sample_indices):
    ax.imshow(train_images[idx])
    ax.set_title(f"Clase: {train_labels[idx]}")
    ax.axis('off')

st.pyplot(fig)

# Streamlit re-executes this script on every interaction; without closing, each
# run would leave another figure registered in pyplot's global state.
plt.close(fig)
|
|
|
|
|
# Let the user upload an image and show the predicted sign name.
st.header("Carga tu Propia Imagen de Señal de Tráfico")
uploaded_file = st.file_uploader("Elige una imagen...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force decoding here so truncated or corrupt
        # files fail inside this handler rather than later during display.
        image.load()
    except OSError as e:
        # Pillow reports unreadable/unidentified images as OSError subclasses.
        st.error(f"No se pudo leer la imagen: {e}")
    else:
        st.image(image, caption='Imagen Cargada', use_column_width=True)
        st.write("")
        st.write("Clasificando...")
        label = predict(image)
        st.write(f"Esta señal es: {label}")