Update app.py
Browse files
app.py
CHANGED
@@ -14,8 +14,11 @@ labels = {name: index for index, name in enumerate(classes)}
|
|
14 |
num_classes = len(classes)
|
15 |
|
16 |
# Load the model
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
19 |
|
20 |
# Predict function for the interface
|
21 |
def predict_fn(image):
|
@@ -29,8 +32,11 @@ def predict_fn(image):
|
|
29 |
The predicted class name.
|
30 |
"""
|
31 |
try:
|
|
|
|
|
|
|
32 |
# Preprocessing the image
|
33 |
-
resized_image = tf.image.resize(
|
34 |
grayscale_image = tf.image.rgb_to_grayscale(resized_image) # Convert to grayscale
|
35 |
image_array = np.array(grayscale_image) / 255.0 # Normalize the image
|
36 |
|
@@ -47,14 +53,12 @@ def predict_fn(image):
|
|
47 |
return f"Error in prediction: {str(e)}"
|
48 |
|
49 |
# Gradio application interface
|
50 |
-
|
51 |
gr.Interface(
|
52 |
fn=predict_fn,
|
53 |
inputs="paint",
|
54 |
outputs="label",
|
55 |
title="DoodleDecoder",
|
56 |
description="Draw something from: Car, House, Wine bottle, Chair, Table, Tree, Camera, Fish, Rain, Clock, Hat",
|
57 |
-
interpretation='default',
|
58 |
article="Draw large with thick stroke."
|
59 |
).launch()
|
60 |
-
|
|
|
14 |
num_classes = len(classes)
|
15 |
|
16 |
# Load the model
|
17 |
+
model_path = 'sketch_recognition_model_cnn.h5' # Ensure this path is correct
|
18 |
+
try:
|
19 |
+
model = load_model(model_path)
|
20 |
+
except Exception as e:
|
21 |
+
raise RuntimeError(f"Failed to load model from {model_path}: {e}")
|
22 |
|
23 |
# Predict function for the interface
|
24 |
def predict_fn(image):
|
|
|
32 |
The predicted class name.
|
33 |
"""
|
34 |
try:
|
35 |
+
# Extract the image data from the input dictionary
|
36 |
+
image_data = image['image'] if isinstance(image, dict) else image
|
37 |
+
|
38 |
# Preprocessing the image
|
39 |
+
resized_image = tf.image.resize(image_data, (28, 28)) # Resize to (28, 28)
|
40 |
grayscale_image = tf.image.rgb_to_grayscale(resized_image) # Convert to grayscale
|
41 |
image_array = np.array(grayscale_image) / 255.0 # Normalize the image
|
42 |
|
|
|
53 |
return f"Error in prediction: {str(e)}"
|
54 |
|
55 |
# Gradio application interface
|
|
|
56 |
gr.Interface(
|
57 |
fn=predict_fn,
|
58 |
inputs="paint",
|
59 |
outputs="label",
|
60 |
title="DoodleDecoder",
|
61 |
description="Draw something from: Car, House, Wine bottle, Chair, Table, Tree, Camera, Fish, Rain, Clock, Hat",
|
62 |
+
interpretation='default', # Add the interpretation parameter here
|
63 |
article="Draw large with thick stroke."
|
64 |
).launch()
|
|