Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
import numpy as np | |
import gradio as gr | |
from tensorflow.keras.preprocessing import image | |
from sklearn.preprocessing import LabelEncoder | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
from tensorflow.keras.metrics import MeanSquaredError | |
# Dictionnaire des sous-catégories | |
subcategory_dict = { | |
"Furniture": ["Home Decor"], | |
"Home Decor": [ | |
"Poufs & Ottomans", "Rugs", "Antique items", "Brass Lamps", | |
"Candle Holders", "Pottery", "Kilim poufs", "Pillow Covers", | |
"Wall Decor", "Straw Lamps" | |
], | |
# Ajoutez d'autres catégories ici | |
} | |
# Fonction pour charger et prétraiter l'image | |
def preprocess_image(img): | |
img = img.resize((224, 224)) # Redimensionner | |
img_array = np.array(img) / 255.0 # Normaliser | |
img_array = np.expand_dims(img_array, axis=0) # Ajouter une dimension batch | |
return img_array | |
# Fonction pour prédire la catégorie, le prix et la sous-catégorie | |
def predict_image(img): | |
# Prétraiter l'image | |
img_array = preprocess_image(img) | |
# Faire les prédictions | |
category_pred, price_pred = model.predict(img_array) | |
# Décoder la catégorie | |
category_pred_class = np.argmax(category_pred, axis=1)[0] | |
category_name = label_encoder.inverse_transform([category_pred_class])[0] | |
# Trouver les sous-catégories correspondantes | |
subcategories = subcategory_dict.get(category_name, []) | |
# Préparer les résultats | |
results = { | |
"Category": category_name, | |
"Price ($)": f"{price_pred[0][0]:.2f}", | |
"Subcategories": subcategories | |
} | |
return results | |
# Charger le modèle pré-entraîné | |
# Assurez-vous que le chemin du modèle et de l'encodeur sont corrects | |
model = tf.keras.models.load_model('trained_model.h5', custom_objects={'mse': MeanSquaredError()}) | |
#label_encoder = LabelEncoder() | |
#label_encoder.classes_ = np.load('path_to_label_encoder_classes.npy') | |
# Interface Gradio | |
interface = gr.Interface( | |
fn=predict_image, | |
inputs=gr.Image(type="pil"), | |
outputs=[ | |
gr.Label(label="Category"), | |
gr.Text(label="Price ($)"), | |
gr.Text(label="Subcategories") | |
], | |
title="Image Classification with TensorFlow", | |
description="Upload an image to predict its category, price, and subcategories." | |
) | |
# Lancer l'interface | |
interface.launch() | |