import os
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import tensorflow as tf
from huggingface_hub import hf_hub_download, login
from PIL import Image
import pickle

# Get the Hugging Face token from the environment variables
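# (On Hugging Face Spaces this is typically provided as a repository secret and
# exposed to the app as an environment variable.)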
hf_token = os.getenv("HF_API_TOKEN")

if hf_token is None:
    st.error("The Hugging Face token is not configured. Please set the 'HF_API_TOKEN' secret.")
    st.stop()

# Authenticate with the token
login(token=hf_token)

# Load the dataset
dataset_repo = "VicGerardoPR/Traffic_sign_dataset"
dataset_file = "train.p"
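# train.p is assumed to be a pickled dict with 'features' (an array of RGB sign
# images) and 'labels' (integer class ids), i.e. the usual GTSRB pickle layout.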

try:
    dataset_path = hf_hub_download(repo_id=dataset_repo, filename=dataset_file, repo_type="dataset")
    with open(dataset_path, 'rb') as file:
        train_data = pickle.load(file)
except Exception as e:
    st.error(f"Error loading the dataset: {e}")
    st.stop()

train_images, train_labels = train_data['features'], train_data['labels']

# Preprocess the data
def preprocess_data(images, labels):
    # Scale pixel values to [0, 1] and encode labels as integer class codes
    images = images.astype('float32') / 255.0
    labels = pd.Series(labels).astype('category').cat.codes.to_numpy()
    return images, labels

train_images, train_labels = preprocess_data(train_images, train_labels)

# Load the model
model_repo = "VicGerardoPR/Traffic_sign_model"
model_file = "traffic_sign_classifier.h5"
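# The .h5 file is assumed to be a Keras CNN trained on 32x32 RGB inputs scaled to
# [0, 1], which is what the predict() helper below feeds it.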

try:
    model_path = hf_hub_download(repo_id=model_repo, filename=model_file, repo_type="model")
    model = tf.keras.models.load_model(model_path)
except Exception as e:
    st.error(f"Error loading the model: {e}")
    st.stop()

# Dictionary of traffic sign classes
classes = {
    0: 'Speed limit (20km/h)',
    1: 'Speed limit (30km/h)',
    2: 'Speed limit (50km/h)',
    3: 'Speed limit (60km/h)',
    4: 'Speed limit (70km/h)',
    5: 'Speed limit (80km/h)',
    6: 'End of speed limit (80km/h)',
    7: 'Speed limit (100km/h)',
    8: 'Speed limit (120km/h)',
    9: 'No passing',
    10: 'No passing for vehicles over 3.5 metric tons',
    11: 'Right-of-way at the next intersection',
    12: 'Priority road',
    13: 'Yield',
    14: 'Stop',
    15: 'No vehicles',
    16: 'Vehicles over 3.5 metric tons prohibited',
    17: 'No entry',
    18: 'General caution',
    19: 'Dangerous curve to the left',
    20: 'Dangerous curve to the right',
    21: 'Double curve',
    22: 'Bumpy road',
    23: 'Slippery road',
    24: 'Road narrows on the right',
    25: 'Road work',
    26: 'Traffic signals',
    27: 'Pedestrians',
    28: 'Children crossing',
    29: 'Bicycles crossing',
    30: 'Beware of ice/snow',
    31: 'Wild animals crossing',
    32: 'End of all speed and passing limits',
    33: 'Turn right ahead',
    34: 'Turn left ahead',
    35: 'Ahead only',
    36: 'Go straight or right',
    37: 'Go straight or left',
    38: 'Keep right',
    39: 'Keep left',
    40: 'Roundabout mandatory',
    41: 'End of no passing',
    42: 'End of no passing by vehicles over 3.5 metric tons'
}

def predict(image):
    # Convert to a 3-channel RGB array (handles RGBA/grayscale uploads),
    # resize to the model's 32x32 input and scale to [0, 1]
    image = np.array(image.convert('RGB'))
    image = cv2.resize(image, (32, 32))
    image = image / 255.0
    image = np.expand_dims(image, axis=0)

    # Return the class with the highest predicted probability
    predictions = model.predict(image)
    class_idx = np.argmax(predictions, axis=1)[0]
    return classes[class_idx]
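
# Optional sketch, not wired into the UI below: how one might surface the top-k
# classes with their softmax confidences instead of only the argmax. This assumes
# the model outputs one probability per class, in the same 0-42 order as `classes`.
def predict_top3(image, k=3):
    image = np.array(image.convert('RGB'))
    image = cv2.resize(image, (32, 32))
    image = image / 255.0
    image = np.expand_dims(image, axis=0)
    probs = model.predict(image)[0]
    top_idx = np.argsort(probs)[::-1][:k]
    return [(classes[int(i)], float(probs[i])) for i in top_idx]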

# Application title and description
st.title("Traffic Sign Classifier")
st.write("This application classifies traffic signs using a CNN model.")

# Show sample images from the dataset
st.header("Sample Images from the Dataset")
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.ravel()
for i in range(10):
    idx = np.random.randint(0, len(train_images))
    axes[i].imshow(train_images[idx])
    axes[i].set_title(f"Class: {train_labels[idx]}")
    axes[i].axis('off')
st.pyplot(fig)

# Let the user upload a traffic sign image
st.header("Upload Your Own Traffic Sign Image")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    st.write("")
    st.write("Classifying...")
    label = predict(image)
    st.write(f"This sign is: {label}")