File size: 2,152 Bytes
cabc13e 5223af2 cabc13e 3b9f116 cabc13e 3b9f116 cabc13e 3b9f116 cabc13e 3b9f116 28f1a2b 9646a7d cabc13e 3b9f116 cabc13e 3b9f116 9646a7d 3b9f116 9646a7d 3b9f116 fa4d0f7 9646a7d fa4d0f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import pandas as pd
import numpy as np
import tensorflow as tf
from keras.models import load_model
import gradio as gr
# Class names the model can predict. A name's position in this list is the
# index the model's output vector uses for it (predict_fn maps an argmax
# index straight back through this list).
classes = [
    'car', 'house', 'wine bottle', 'chair', 'table',
    'tree', 'camera', 'fish', 'rain', 'clock', 'hat',
]
num_classes = len(classes)
# Reverse lookup: class name -> output index.
labels = dict(zip(classes, range(num_classes)))
# Load the trained Keras model once, at import time, so every prediction
# request reuses the same instance. `load_model` is already imported at the
# top of the file; the duplicate import that used to live here was redundant.
# The path is relative, so it is resolved against the process's current
# working directory -- ensure this path is correct for your deployment.
model_path = 'sketch_recognition_model_cnn.h5'
try:
    model = load_model(model_path)
except Exception as e:
    # Fail fast with a clear message, keeping the original error chained as
    # __cause__ so the full traceback survives.
    raise RuntimeError(f"Failed to load model from {model_path}: {e}") from e
# Predict function for the interface
def predict_fn(image):
    """
    Predict the class of a drawn image.

    Args:
        image: The drawing submitted by the Gradio input component. Either an
            image array (H x W, H x W x 1, or H x W x 3) or a dict wrapping
            one under the ``'image'`` key (as sketch-style Gradio inputs do).
            ``None`` is treated as "nothing was drawn".

    Returns:
        The predicted class name from ``classes``. If the input cannot be
        processed, a string starting with ``"Error in prediction: "`` is
        returned instead of raising, so the UI shows the problem rather
        than crashing.
    """
    try:
        # Sketch-style inputs wrap the picture in a dict; plain inputs pass
        # the array directly. `.get` avoids an opaque KeyError message.
        image_data = image.get('image') if isinstance(image, dict) else image
        if image_data is None:
            return "Error in prediction: no drawing was submitted"

        pixels = np.asarray(image_data)
        if pixels.ndim == 2:
            # Already single-channel; tf.image.resize needs a channel axis.
            pixels = pixels[..., np.newaxis]

        # Preprocess: resize to the model's 28x28 input, then collapse RGB to
        # a single grayscale channel (rgb_to_grayscale requires exactly 3
        # channels, so single-channel input skips it).
        resized_image = tf.image.resize(pixels, (28, 28))
        if resized_image.shape[-1] == 3:
            resized_image = tf.image.rgb_to_grayscale(resized_image)
        image_array = np.asarray(resized_image, dtype=np.float32) / 255.0

        # Add the batch dimension expected by the model: (1, 28, 28, 1).
        image_array = image_array.reshape(1, 28, 28, 1)
        # verbose=0 stops Keras printing a progress bar on every request.
        predictions = np.asarray(model.predict(image_array, verbose=0)).reshape(num_classes)

        predicted_index = int(np.argmax(predictions))
        return classes[predicted_index]
    except Exception as e:
        # UI boundary: surface the failure as the displayed label.
        return f"Error in prediction: {str(e)}"
# Gradio application interface.
# The description is generated from `classes` so the list shown to users
# cannot drift out of sync with the labels the model actually predicts
# (it currently renders as "Car, House, Wine bottle, ..., Hat").
demo = gr.Interface(
    fn=predict_fn,
    inputs="paint",
    outputs="label",
    title="DoodleDecoder",
    description="Draw something from: " + ", ".join(name.capitalize() for name in classes),
    # NOTE(review): `interpretation` appears to be a Gradio 3.x-only option
    # (removed in Gradio 4.0) -- confirm against the pinned gradio version.
    interpretation='default',
    article="Draw large with thick stroke.",
)
demo.launch()
|