VicGerardoPR commited on
Commit
1e3e204
·
verified ·
1 Parent(s): 6642c17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -5,18 +5,22 @@ import matplotlib.pyplot as plt
5
  import seaborn as sns
6
  import cv2
7
  import torch
8
- from transformers import AutoTokenizer, AutoModelForImageClassification
9
- from huggingface_hub import hf_hub_download
10
  from PIL import Image
11
  import pickle
12
 
 
13
  # Cargar dataset
14
  dataset_repo = "VicGerardoPR/Traffic_sign_dataset"
15
  dataset_file = "train.p"
16
 
17
- dataset_path = hf_hub_download(repo_id=dataset_repo, filename=dataset_file)
18
- with open(dataset_path, 'rb') as file:
19
- train_data = pickle.load(file)
 
 
 
 
20
 
21
  train_images, train_labels = train_data['features'], train_data['labels']
22
 
@@ -32,9 +36,14 @@ train_images, train_labels = preprocess_data(train_images, train_labels)
32
  model_repo = "VicGerardoPR/Traffic_sign_model"
33
  model_file = "traffic_sign_classifier.h5"
34
 
35
- model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
36
- model = torch.load(model_path, map_location=torch.device('cpu'))
37
- model.eval()
 
 
 
 
 
38
 
39
  # Diccionario de clases de señales de tráfico
40
  classes = {
 
5
  import seaborn as sns
6
  import cv2
7
  import torch
8
+ from huggingface_hub import hf_hub_download, login
 
9
  from PIL import Image
10
  import pickle
11
 
12
+
13
  # Cargar dataset
14
  dataset_repo = "VicGerardoPR/Traffic_sign_dataset"
15
  dataset_file = "train.p"
16
 
17
+ try:
18
+ dataset_path = hf_hub_download(repo_id=dataset_repo, filename=dataset_file, repo_type="dataset")
19
+ with open(dataset_path, 'rb') as file:
20
+ train_data = pickle.load(file)
21
+ except Exception as e:
22
+ st.error(f"Error al cargar el dataset: {e}")
23
+ st.stop()
24
 
25
  train_images, train_labels = train_data['features'], train_data['labels']
26
 
 
36
  model_repo = "VicGerardoPR/Traffic_sign_model"
37
  model_file = "traffic_sign_classifier.h5"
38
 
39
+ try:
40
+ model_path = hf_hub_download(repo_id=model_repo, filename=model_file, repo_type="model")
41
+ model = torch.load(model_path, map_location=torch.device('cpu'))
42
+ model.eval()
43
+ except Exception as e:
44
+ st.error(f"Error al cargar el modelo: {e}")
45
+ st.stop()
46
+
47
 
48
  # Diccionario de clases de señales de tráfico
49
  classes = {