# Hugging Face Space page header captured with this source:
# Space status "Runtime error".
import os

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from PIL import Image
# Load the model from Hugging Face.
# NOTE: "/resolve/" serves the raw file bytes. A "/blob/" URL returns the HTML
# file-viewer page instead, which tf.keras cannot load as a model.
model_url = "https://huggingface.co/sebastiancgeorge/ensembled_waste_classification/resolve/main/ensemble_waste_classifier%20(1).keras"
model_path = "ensemble_waste_classifier (1).keras"


def _download_model(url, path, timeout=60):
    """Download *url* to *path*, streaming through a temporary file.

    The body is written to ``path + ".part"`` and renamed onto *path* only
    after the transfer completes. An interrupted download therefore never
    leaves a truncated file that the existence check below would mistake for
    a cached model.

    Args:
        url: Direct-download URL of the model file.
        path: Local destination path.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    tmp_path = path + ".part"
    with requests.get(url, stream=True, allow_redirects=True, timeout=timeout) as response:
        # Fail loudly instead of saving an error page as the model file.
        response.raise_for_status()
        with open(tmp_path, "wb") as fh:
            # Stream in 1 MiB chunks so the whole model is never held in memory.
            for chunk in response.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    os.replace(tmp_path, path)


# Download the model only if it is not already present locally.
if not os.path.exists(model_path):
    _download_model(model_url, model_path)

# Load the model
model = tf.keras.models.load_model(model_path)
# Output class names. Position i labels the model's i-th prediction score.
CLASS_LABELS = [
    "Cardboard",
    "Glass",
    "Metal",
    "Paper",
    "Plastic",
    "Trash",
]
# Preprocess the input image
def preprocess_image(image, target_size=(224, 224)):
    """Convert a PIL image into a single-image batch for the model.

    Args:
        image: PIL image from the Gradio input. Any mode is accepted; it is
            converted to 3-channel RGB first.
        target_size: ``(width, height)`` passed to ``Image.resize``. The
            default keeps the original 224x224; confirm it matches the
            model's input size if the model changes.

    Returns:
        float32 array of shape ``(1, height, width, 3)`` with pixel values
        divided by 255.
    """
    # Uploads may be RGBA (e.g. PNG), grayscale "L" or palette "P". Without
    # conversion np.asarray would yield 4 or 1 channels instead of 3, which
    # the model presumably cannot accept.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(target_size)
    pixels = np.asarray(image, dtype=np.float32) / 255.0  # Normalize
    return np.expand_dims(pixels, axis=0)  # Add batch dimension
# Define prediction function
def classify_image(image):
    """Classify an uploaded image and return per-class confidence scores.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        dict mapping each label in ``CLASS_LABELS`` to the model's score for
        that class as a Python float, the mapping format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
        ValueError: if the model's output length differs from CLASS_LABELS.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")
    batch = preprocess_image(image)
    # verbose=0 keeps Keras from printing a progress bar on every request.
    predictions = model.predict(batch, verbose=0)[0]
    if len(predictions) != len(CLASS_LABELS):
        raise ValueError(
            f"Model returned {len(predictions)} scores but "
            f"{len(CLASS_LABELS)} class labels are defined."
        )
    return {label: float(score) for label, score in zip(CLASS_LABELS, predictions)}
# Gradio UI: a PIL image upload goes in, per-class confidences come out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    description=(
        "Upload an image to classify waste into categories: "
        "Cardboard, Glass, Metal, Paper, Plastic, Trash."
    ),
    title="Waste Classification Model",
)

# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    iface.launch()