import tensorflow as tf import numpy as np import gradio as gr from tensorflow.keras.preprocessing import image from sklearn.preprocessing import LabelEncoder import matplotlib.pyplot as plt from PIL import Image from tensorflow.keras.metrics import MeanSquaredError # Dictionnaire des sous-catégories subcategory_dict = { "Furniture": ["Home Decor"], "Home Decor": [ "Poufs & Ottomans", "Rugs", "Antique items", "Brass Lamps", "Candle Holders", "Pottery", "Kilim poufs", "Pillow Covers", "Wall Decor", "Straw Lamps" ], # Ajoutez d'autres catégories ici } # Fonction pour charger et prétraiter l'image def preprocess_image(img): img = img.resize((224, 224)) # Redimensionner img_array = np.array(img) / 255.0 # Normaliser img_array = np.expand_dims(img_array, axis=0) # Ajouter une dimension batch return img_array # Fonction pour prédire la catégorie, le prix et la sous-catégorie def predict_image(img): # Prétraiter l'image img_array = preprocess_image(img) # Faire les prédictions category_pred, price_pred = model.predict(img_array) # Décoder la catégorie category_pred_class = np.argmax(category_pred, axis=1)[0] category_name = label_encoder.inverse_transform([category_pred_class])[0] # Trouver les sous-catégories correspondantes subcategories = subcategory_dict.get(category_name, []) # Préparer les résultats results = { "Category": category_name, "Price ($)": f"{price_pred[0][0]:.2f}", "Subcategories": subcategories } return results # Charger le modèle pré-entraîné # Assurez-vous que le chemin du modèle et de l'encodeur sont corrects model = tf.keras.models.load_model('trained_model.h5', custom_objects={'mse': MeanSquaredError()}) label_encoder = LabelEncoder() label_encoder.classes_ = np.load('path_to_label_encoder_classes.npy') # Interface Gradio interface = gr.Interface( fn=predict_image, inputs=gr.Image(type="pil"), outputs=[ gr.Label(label="Category"), gr.Text(label="Price ($)"), gr.Text(label="Subcategories") ], title="Image Classification with TensorFlow", description="Upload an image to predict its category, price, and subcategories." ) # Lancer l'interface interface.launch()