VicGerardoPR's picture
URL change
6642c17
raw
history blame
3.93 kB
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import torch
from transformers import AutoTokenizer, AutoModelForImageClassification
from huggingface_hub import hf_hub_download
from PIL import Image
import pickle
# Load the training dataset: download the pickled split from the Hugging Face
# Hub and pull out the image array and the label array.
dataset_repo = "VicGerardoPR/Traffic_sign_dataset"
dataset_file = "train.p"
dataset_path = hf_hub_download(filename=dataset_file, repo_id=dataset_repo)
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file; this is only acceptable because the repo above is presumably trusted.
with open(dataset_path, 'rb') as pickled:
    train_data = pickle.load(pickled)
train_images = train_data['features']
train_labels = train_data['labels']
# Preprocesar datos
def preprocess_data(images, labels):
images = images.astype('float32') / 255.0
labels = pd.Series(labels).astype('category').cat.codes.to_numpy()
return images, labels
# Normalize the loaded training images/labels before they are displayed.
train_images, train_labels = preprocess_data(images=train_images, labels=train_labels)
# Load the model: download the trained classifier from the Hugging Face Hub,
# map its tensors onto the CPU and switch it to inference mode.
model_repo = "VicGerardoPR/Traffic_sign_model"
model_file = "traffic_sign_classifier.h5"
model_path = hf_hub_download(
    filename=model_file,
    repo_id=model_repo,
)
# NOTE(review): an ".h5" extension usually denotes a Keras/HDF5 file, while
# torch.load only reads PyTorch-serialized files — confirm the artifact format.
# SECURITY NOTE(review): torch.load unpickles arbitrary objects; only load
# from a trusted repo. Recent PyTorch releases default to weights_only=True,
# which refuses whole pickled modules — confirm against the pinned torch.
model = torch.load(
    model_path,
    map_location=torch.device('cpu'),
)
# train(False) is what nn.Module.eval() does: inference-mode behaviour for
# layers such as dropout / batch-norm, if the model has any.
model.train(False)
# Traffic-sign class dictionary: maps the integer class index produced by the
# model (see predict) to a human-readable sign name. The ids are the list
# positions 0-42.
classes = dict(enumerate([
    'Speed limit (20km/h)',  # 0
    'Speed limit (30km/h)',  # 1
    'Speed limit (50km/h)',  # 2
    'Speed limit (60km/h)',  # 3
    'Speed limit (70km/h)',  # 4
    'Speed limit (80km/h)',  # 5
    'End of speed limit (80km/h)',  # 6
    'Speed limit (100km/h)',  # 7
    'Speed limit (120km/h)',  # 8
    'No passing',  # 9
    'No passing for vehicles over 3.5 metric tons',  # 10
    'Right-of-way at the next intersection',  # 11
    'Priority road',  # 12
    'Yield',  # 13
    'Stop',  # 14
    'No vehicles',  # 15
    'Vehicles over 3.5 metric tons prohibited',  # 16
    'No entry',  # 17
    'General caution',  # 18
    'Dangerous curve to the left',  # 19
    'Dangerous curve to the right',  # 20
    'Double curve',  # 21
    'Bumpy road',  # 22
    'Slippery road',  # 23
    'Road narrows on the right',  # 24
    'Road work',  # 25
    'Traffic signals',  # 26
    'Pedestrians',  # 27
    'Children crossing',  # 28
    'Bicycles crossing',  # 29
    'Beware of ice/snow',  # 30
    'Wild animals crossing',  # 31
    'End of all speed and passing limits',  # 32
    'Turn right ahead',  # 33
    'Turn left ahead',  # 34
    'Ahead only',  # 35
    'Go straight or right',  # 36
    'Go straight or left',  # 37
    'Keep right',  # 38
    'Keep left',  # 39
    'Roundabout mandatory',  # 40
    'End of no passing',  # 41
    'End of no passing by vehicles over 3.5 metric tons',  # 42
]))
def predict(image):
    """Classify one traffic-sign image and return its class name.

    Args:
        image: a PIL image, or an array-like of pixels shaped (H, W) for
            grayscale or (H, W, C) with 3 or 4 channels. Values are presumably
            0-255 — TODO confirm for non-PIL callers.

    Returns:
        The ``classes`` entry for the index with the highest model output.
    """
    # Uploaded PNGs are often RGBA, palette ("P") or grayscale ("L"). The model
    # is presumably trained on 3-channel RGB (TODO confirm), so normalize PIL
    # images to RGB first. convert() drops alpha without compositing.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    pixels = np.asarray(image)
    if pixels.ndim == 2:
        # Grayscale array: replicate the single channel into R, G and B.
        pixels = np.stack([pixels] * 3, axis=-1)
    elif pixels.shape[-1] == 4:
        # RGBA array: drop the alpha channel.
        pixels = np.ascontiguousarray(pixels[..., :3])
    pixels = cv2.resize(pixels, (32, 32))
    scaled = pixels / 255.0
    # (H, W, C) -> (1, C, H, W): add a batch axis, then move channels first
    # as PyTorch convolutions expect.
    batch = np.transpose(np.expand_dims(scaled, axis=0), (0, 3, 1, 2))
    tensor = torch.tensor(batch).float()
    with torch.no_grad():
        predictions = model(tensor)
    class_idx = predictions.argmax().item()
    return classes[class_idx]
# App title and description.
st.title("Traffic Sign Classifier")
st.write("Esta aplicación clasifica señales de tráfico usando un modelo de CNN.")
# Show a 2x5 grid of random example images from the training dataset.
st.header("Ejemplos de Imágenes del Conjunto de Datos")
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.ravel()
# Draw distinct indices so no image appears twice. Cap the sample count at the
# dataset size so a small dataset leaves empty tiles instead of failing.
num_samples = min(len(axes), len(train_images))
sample_idx = np.random.choice(len(train_images), size=num_samples, replace=False)
for ax in axes:
    ax.axis('off')
for ax, idx in zip(axes, sample_idx):
    ax.imshow(train_images[idx])
    ax.set_title(f"Clase: {train_labels[idx]}")
st.pyplot(fig)
# Streamlit re-executes this script on every interaction. Close the figure so
# pyplot does not keep an ever-growing set of open figures across reruns.
plt.close(fig)
# Let the user upload their own traffic-sign image, display it and classify it.
st.header("Carga tu Propia Imagen de Señal de Tráfico")
uploaded_file = st.file_uploader("Elige una imagen...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy. Force a full decode here so a corrupt or
        # truncated file fails inside this try, not later in st.image/predict.
        image.load()
    except OSError as err:
        # Pillow signals undecodable files with OSError subclasses
        # (e.g. UnidentifiedImageError); show a message, not a traceback.
        st.error(f"No se pudo leer la imagen: {err}")
    else:
        st.image(image, caption='Imagen Cargada', use_column_width=True)
        st.write("")
        st.write("Clasificando...")
        label = predict(image)
        st.write(f"Esta señal es: {label}")