File size: 3,933 Bytes
d5541c9
9a43209
d5541c9
9a43209
 
 
d5541c9
 
 
9a43209
d5541c9
9a43209
d5541c9
6642c17
d5541c9
9a43209
d5541c9
 
 
9a43209
 
 
 
 
 
d5541c9
9a43209
 
 
 
d5541c9
6642c17
d5541c9
9a43209
d5541c9
 
 
9a43209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5541c9
 
 
 
 
 
9a43209
 
 
 
 
 
 
 
 
 
 
 
 
d5541c9
9a43209
 
 
 
 
 
 
 
 
 
 
 
d5541c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import torch
from transformers import AutoTokenizer, AutoModelForImageClassification
from huggingface_hub import hf_hub_download
from PIL import Image
import pickle

# Load the dataset
dataset_repo = "VicGerardoPR/Traffic_sign_dataset"
dataset_file = "train.p"


@st.cache_resource
def _load_pickled_dataset(repo_id, filename):
    """Download ``filename`` from the Hugging Face Hub repo ``repo_id`` and unpickle it.

    Returns ``(local_path, unpickled_object)``.

    Cached with ``st.cache_resource`` because Streamlit re-executes this
    script on every user interaction; without the cache the pickle would be
    read and deserialized again on every rerun. The cached object is shared
    between reruns, so callers must not mutate it in place.
    """
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # this fully trusts the contents of the remote repo.
    with open(path, 'rb') as file:
        data = pickle.load(file)
    return path, data


dataset_path, train_data = _load_pickled_dataset(dataset_repo, dataset_file)

train_images, train_labels = train_data['features'], train_data['labels']

# Data preprocessing
def preprocess_data(images, labels):
    """Return ``(scaled_images, label_codes)``.

    ``scaled_images`` is ``images`` cast to float32 and divided by 255, so
    8-bit pixel values land in [0, 1]. ``label_codes`` replaces each label
    with its integer position among the distinct label values (via a pandas
    categorical), returned as a NumPy array.
    """
    scaled = images.astype(np.float32) / 255.0
    categorical = pd.Series(labels).astype('category')
    codes = categorical.cat.codes.to_numpy()
    return scaled, codes

# Replace the raw training arrays with their scaled / re-encoded versions.
train_images, train_labels = preprocess_data(images=train_images, labels=train_labels)

# Load the model
model_repo = "VicGerardoPR/Traffic_sign_model"
model_file = "traffic_sign_classifier.h5"


@st.cache_resource
def _load_model(repo_id, filename):
    """Download ``filename`` from ``repo_id`` and return ``(path, model)``.

    The model is mapped onto the CPU and switched to eval mode. It is cached
    with ``st.cache_resource`` because Streamlit re-executes this script on
    every user interaction; without the cache the checkpoint would be
    deserialized again for every uploaded image.
    """
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    # The checkpoint is used as a whole module (``.eval()`` is called on it),
    # not a state_dict, so full unpickling is required. torch >= 2.6 defaults
    # to weights_only=True and would refuse to load it.
    # SECURITY: full unpickling runs arbitrary code from the remote file;
    # only load checkpoints from a repo you trust.
    # NOTE(review): ".h5" usually denotes a Keras/HDF5 file, which torch.load
    # cannot read -- confirm the artifact really is a torch.save() output.
    loaded = torch.load(path, map_location=torch.device('cpu'), weights_only=False)
    loaded.eval()
    return path, loaded


model_path, model = _load_model(model_repo, model_file)

# Traffic-sign names indexed by the model's output position: entry i of the
# list below becomes key i of the dictionary.
classes = dict(enumerate([
    'Speed limit (20km/h)',
    'Speed limit (30km/h)',
    'Speed limit (50km/h)',
    'Speed limit (60km/h)',
    'Speed limit (70km/h)',
    'Speed limit (80km/h)',
    'End of speed limit (80km/h)',
    'Speed limit (100km/h)',
    'Speed limit (120km/h)',
    'No passing',
    'No passing for vehicles over 3.5 metric tons',
    'Right-of-way at the next intersection',
    'Priority road',
    'Yield',
    'Stop',
    'No vehicles',
    'Vehicles over 3.5 metric tons prohibited',
    'No entry',
    'General caution',
    'Dangerous curve to the left',
    'Dangerous curve to the right',
    'Double curve',
    'Bumpy road',
    'Slippery road',
    'Road narrows on the right',
    'Road work',
    'Traffic signals',
    'Pedestrians',
    'Children crossing',
    'Bicycles crossing',
    'Beware of ice/snow',
    'Wild animals crossing',
    'End of all speed and passing limits',
    'Turn right ahead',
    'Turn left ahead',
    'Ahead only',
    'Go straight or right',
    'Go straight or left',
    'Keep right',
    'Keep left',
    'Roundabout mandatory',
    'End of no passing',
    'End of no passing by vehicles over 3.5 metric tons',
]))

def predict(image):
    """Classify a single traffic-sign image and return its label.

    The image is normalized to 3 channels, resized to 32x32, scaled by
    1/255, reordered to (batch, channel, height, width) and passed to the
    module-level ``model``.

    Args:
        image: A ``PIL.Image.Image`` (e.g. from ``Image.open``) or an
            array-like of pixels shaped (H, W), (H, W, 3) or (H, W, 4).

    Returns:
        The entry of ``classes`` for the highest-scoring model output.

    Raises:
        KeyError: if the predicted index has no entry in ``classes``.
    """
    # Uploaded PNGs are often RGBA or palette ('P') mode, and some images are
    # grayscale. Without this step the tensor would have 4 channels, or the
    # 4-axis transpose below would fail on a 2-D array. The model presumably
    # expects 3-channel input like the training images -- TODO confirm.
    if isinstance(image, Image.Image):
        image = image.convert('RGB')
    pixels = np.array(image)
    if pixels.ndim == 2:
        pixels = np.stack([pixels] * 3, axis=-1)
    elif pixels.shape[-1] == 4:
        pixels = pixels[..., :3]

    pixels = cv2.resize(pixels, (32, 32))
    pixels = pixels / 255.0
    # (H, W, C) -> (1, C, H, W): add a batch axis, then move channels first
    # for PyTorch.
    batch = np.transpose(np.expand_dims(pixels, axis=0), (0, 3, 1, 2))
    tensor = torch.tensor(batch).float()

    with torch.no_grad():
        predictions = model(tensor)
        class_idx = predictions.argmax().item()
    return classes[class_idx]

# App title and description
st.title("Traffic Sign Classifier")
st.write("Esta aplicación clasifica señales de tráfico usando un modelo de CNN.")

# Show a grid of randomly chosen images from the dataset
st.header("Ejemplos de Imágenes del Conjunto de Datos")
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.ravel()
for ax in axes:
    idx = np.random.randint(0, len(train_images))
    ax.imshow(train_images[idx])
    ax.set_title(f"Clase: {train_labels[idx]}")
    ax.axis('off')
st.pyplot(fig)
# Streamlit re-runs this script on every interaction. Without closing, each
# rerun leaves another figure registered with pyplot, which grows memory and
# triggers matplotlib's "too many open figures" warning. The figure has
# already been rendered by st.pyplot, so it is safe to close here.
plt.close(fig)

# Let the user upload their own image and classify it
st.header("Carga tu Propia Imagen de Señal de Tráfico")
uploaded_file = st.file_uploader("Elige una imagen...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Imagen Cargada', use_column_width=True)
    st.write("")
    st.write("Clasificando...")
    label = predict(image)
    st.write(f"Esta señal es: {label}")