saaara commited on
Commit
8728306
·
verified ·
1 Parent(s): 0d6ba95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -5
app.py CHANGED
@@ -34,7 +34,7 @@ def predict_image(img):
34
  category_pred, price_pred = model.predict(img_array)
35
 
36
  # Décoder la catégorie
37
- category_pred_class = np.argmax(category_pred, axis=1)[0]
38
  category_name = label_encoder.inverse_transform([category_pred_class])[0]
39
 
40
  # Trouver les sous-catégories correspondantes
@@ -44,16 +44,13 @@ def predict_image(img):
44
  results = {
45
  "Category": category_name,
46
  "Price ($)": f"{price_pred[0][0]:.2f}",
47
- "Subcategories": subcategories
48
  }
49
  return results
50
 
51
  # Charger le modèle pré-entraîné
52
- # Assurez-vous que le chemin du modèle et de l'encodeur sont corrects
53
  model = tf.keras.models.load_model('trained_model.h5', custom_objects={'mse': MeanSquaredError()})
54
 
55
- #label_encoder = LabelEncoder()
56
- #label_encoder.classes_ = np.load('path_to_label_encoder_classes.npy')
57
 
58
  # Interface Gradio
59
  interface = gr.Interface(
 
34
  category_pred, price_pred = model.predict(img_array)
35
 
36
  # Décoder la catégorie
37
+ category_pred_class = np.argmax(category_pred, axis=1)[0] # La classe avec la plus haute probabilité
38
  category_name = label_encoder.inverse_transform([category_pred_class])[0]
39
 
40
  # Trouver les sous-catégories correspondantes
 
44
  results = {
45
  "Category": category_name,
46
  "Price ($)": f"{price_pred[0][0]:.2f}",
47
+ "Subcategories": ", ".join(subcategories) if subcategories else "No subcategories"
48
  }
49
  return results
50
 
51
  # Charger le modèle pré-entraîné
 
52
  model = tf.keras.models.load_model('trained_model.h5', custom_objects={'mse': MeanSquaredError()})
53
 
 
 
54
 
55
  # Interface Gradio
56
  interface = gr.Interface(