# Hugging Face Space source file app.py, author: VicGerardoPR
# Last commit: "Update app.py" (25a9506, verified)
import os
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
import pickle
# Read the Hugging Face access token from the environment. The error message
# below tells the operator to provide it as the 'HF_API_TOKEN' secret.
hf_token = os.getenv("HF_API_TOKEN")
# `not hf_token` also rejects an empty string (variable present but blank),
# which login() could not use either; fail early with the friendly message.
if not hf_token:
    st.error("El token de Hugging Face no está configurado. Por favor, configura el secreto 'HF_API_TOKEN'.")
    st.stop()

# Authenticate with the token so the hf_hub_download calls below use it.
from huggingface_hub import login

try:
    login(token=hf_token)
except Exception as e:
    # Report a rejected token in the UI instead of a raw traceback, the same
    # way the dataset/model loading steps report their failures.
    st.error(f"Error al autenticarse con Hugging Face: {e}")
    st.stop()
# Load the training dataset from the Hugging Face Hub.
dataset_repo = "VicGerardoPR/Traffic_sign_dataset"
dataset_file = "train.p"


# Streamlit re-executes this whole script on every user interaction. Caching
# the loaded object means hf_hub_download + pickle.load run once per process
# instead of on every rerun. Failures raise and are therefore not cached.
@st.cache_resource
def _load_dataset(repo_id, filename):
    """Download ``filename`` from dataset repo ``repo_id`` and unpickle it.

    Raises whatever hf_hub_download / open / pickle.load raise; the caller
    reports the error in the UI.
    """
    path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
    # SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
    # the file. This is presumably acceptable only because the repo appears to
    # belong to this app's author; never point this at an untrusted dataset.
    with open(path, 'rb') as file:
        return pickle.load(file)


try:
    train_data = _load_dataset(dataset_repo, dataset_file)
    # Extract inside the try so a file without the expected 'features' /
    # 'labels' keys is reported in the UI rather than crashing with KeyError.
    train_images, train_labels = train_data['features'], train_data['labels']
except Exception as e:
    st.error(f"Error al cargar el dataset: {e}")
    st.stop()
# Preprocesar datos
def preprocess_data(images, labels):
images = images.astype('float32') / 255.0
labels = pd.Series(labels).astype('category').cat.codes.to_numpy()
return images, labels
# Replace the raw arrays with their preprocessed form: pixels become float32
# scaled by 1/255, and labels become integer category codes.
train_images, train_labels = preprocess_data(train_images, train_labels)
# Load the trained Keras classifier from the Hugging Face Hub.
model_repo = "VicGerardoPR/Traffic_sign_model"
model_file = "traffic_sign_classifier.h5"


# Streamlit re-executes this script on every interaction. Without caching, the
# model would be re-downloaded and rebuilt by load_model() on each rerun
# (e.g. every image upload). Failures raise and are therefore not cached.
@st.cache_resource
def _load_model(repo_id, filename):
    """Download ``filename`` from model repo ``repo_id`` and load it with Keras.

    Raises whatever hf_hub_download / tf.keras.models.load_model raise; the
    caller reports the error in the UI.
    """
    path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
    return tf.keras.models.load_model(path)


try:
    model = _load_model(model_repo, model_file)
except Exception as e:
    st.error(f"Error al cargar el modelo: {e}")
    st.stop()
# Lookup from model output index to the traffic-sign name shown to the user.
# Names are listed in index order; enumerate() supplies the integer keys 0..42.
classes = dict(enumerate([
    'Speed limit (20km/h)',                                  # 0
    'Speed limit (30km/h)',
    'Speed limit (50km/h)',
    'Speed limit (60km/h)',
    'Speed limit (70km/h)',
    'Speed limit (80km/h)',
    'End of speed limit (80km/h)',
    'Speed limit (100km/h)',
    'Speed limit (120km/h)',
    'No passing',
    'No passing for vehicles over 3.5 metric tons',          # 10
    'Right-of-way at the next intersection',
    'Priority road',
    'Yield',
    'Stop',
    'No vehicles',
    'Vehicles over 3.5 metric tons prohibited',
    'No entry',
    'General caution',
    'Dangerous curve to the left',
    'Dangerous curve to the right',                          # 20
    'Double curve',
    'Bumpy road',
    'Slippery road',
    'Road narrows on the right',
    'Road work',
    'Traffic signals',
    'Pedestrians',
    'Children crossing',
    'Bicycles crossing',
    'Beware of ice/snow',                                    # 30
    'Wild animals crossing',
    'End of all speed and passing limits',
    'Turn right ahead',
    'Turn left ahead',
    'Ahead only',
    'Go straight or right',
    'Go straight or left',
    'Keep right',
    'Keep left',
    'Roundabout mandatory',                                  # 40
    'End of no passing',
    'End of no passing by vehicles over 3.5 metric tons',    # 42
]))
def predict(image):
    """Classify one traffic-sign image and return its name from ``classes``.

    Args:
        image: a PIL image, or an array-like of pixel values (presumably 8-bit
            RGB / greyscale / RGBA; TODO confirm for non-PIL callers).

    Returns:
        The ``classes`` entry for the highest-scoring model output.
    """
    # Uploaded PNGs are often RGBA, palette ("P") or greyscale. Passing those
    # through unchanged gives the model 4 or 1 channels (or palette indices)
    # instead of RGB pixels, so normalise to 3-channel RGB first.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    pixels = np.asarray(image)
    if pixels.ndim == 2:
        pixels = np.stack([pixels] * 3, axis=-1)
    elif pixels.shape[-1] == 4:
        pixels = np.ascontiguousarray(pixels[..., :3])
    pixels = cv2.resize(pixels, (32, 32))
    # Same scaling and dtype as preprocess_data, plus a leading batch axis.
    batch = np.expand_dims(pixels.astype('float32') / 255.0, axis=0)
    predictions = model.predict(batch, verbose=0)
    # Plain int so the dict lookup does not rely on numpy-int hashing.
    class_idx = int(np.argmax(predictions, axis=1)[0])
    return classes[class_idx]
# App title and description.
st.title("Traffic Sign Classifier")
st.write("Esta aplicación clasifica señales de tráfico usando un modelo de CNN.")

# Show a 2x5 grid of random example images from the training set.
st.header("Ejemplos de Imágenes del Conjunto de Datos")
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.ravel()
# Sample without replacement so the grid never shows the same image twice,
# and never request more samples than the dataset holds.
n_samples = min(len(axes), len(train_images))
sample_indices = np.random.choice(len(train_images), size=n_samples, replace=False)
for ax in axes:
    ax.axis('off')
for ax, idx in zip(axes, sample_indices):
    ax.imshow(train_images[idx])
    ax.set_title(f"Clase: {train_labels[idx]}")
st.pyplot(fig)
# This script re-runs on every interaction. Close the figure once it is
# rendered so pyplot does not keep one more open figure per rerun.
plt.close(fig)
# Let the user upload their own traffic-sign image and classify it.
st.header("Carga tu Propia Imagen de Señal de Tráfico")
uploaded_file = st.file_uploader("Elige una imagen...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open() only reads the header; load() forces a full decode here,
        # so corrupt or truncated files fail inside this try, not in predict().
        image.load()
    except OSError as e:
        # PIL reports undecodable data with OSError (UnidentifiedImageError is
        # a subclass). The uploader's `type=` list is an extension check only.
        st.error(f"No se pudo leer la imagen: {e}")
    else:
        st.image(image, caption='Imagen Cargada', use_column_width=True)
        st.write("")
        st.write("Clasificando...")
        label = predict(image)
        st.write(f"Esta señal es: {label}")