VicGerardoPR commited on
Commit
25a9506
verified
1 Parent(s): 0b3b5f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -9
app.py CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
  import cv2
8
- import torch
9
  from huggingface_hub import hf_hub_download
10
  from PIL import Image
11
  import pickle
@@ -49,13 +49,11 @@ model_file = "traffic_sign_classifier.h5"
49
 
50
  try:
51
  model_path = hf_hub_download(repo_id=model_repo, filename=model_file, repo_type="model")
52
- model = torch.load(model_path, map_location=torch.device('cpu'))
53
- model.eval()
54
  except Exception as e:
55
  st.error(f"Error al cargar el modelo: {e}")
56
  st.stop()
57
 
58
-
59
  # Diccionario de clases de se帽ales de tr谩fico
60
  classes = {
61
  0: 'Speed limit (20km/h)',
@@ -108,12 +106,9 @@ def predict(image):
108
  image = cv2.resize(image, (32, 32))
109
  image = image / 255.0
110
  image = np.expand_dims(image, axis=0)
111
- image = np.transpose(image, (0, 3, 1, 2)) # Reordenar dimensiones para PyTorch
112
- image = torch.tensor(image).float()
113
 
114
- with torch.no_grad():
115
- predictions = model(image)
116
- class_idx = predictions.argmax().item()
117
  return classes[class_idx]
118
 
119
  # T铆tulo y descripci贸n de la aplicaci贸n
 
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
  import cv2
8
+ import tensorflow as tf
9
  from huggingface_hub import hf_hub_download
10
  from PIL import Image
11
  import pickle
 
49
 
50
  try:
51
  model_path = hf_hub_download(repo_id=model_repo, filename=model_file, repo_type="model")
52
+ model = tf.keras.models.load_model(model_path)
 
53
  except Exception as e:
54
  st.error(f"Error al cargar el modelo: {e}")
55
  st.stop()
56
 
 
57
  # Diccionario de clases de se帽ales de tr谩fico
58
  classes = {
59
  0: 'Speed limit (20km/h)',
 
106
  image = cv2.resize(image, (32, 32))
107
  image = image / 255.0
108
  image = np.expand_dims(image, axis=0)
 
 
109
 
110
+ predictions = model.predict(image)
111
+ class_idx = np.argmax(predictions, axis=1)[0]
 
112
  return classes[class_idx]
113
 
114
  # T铆tulo y descripci贸n de la aplicaci贸n