Apex-X commited on
Commit
9646a7d
·
verified ·
1 Parent(s): fa4d0f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -14,8 +14,11 @@ labels = {name: index for index, name in enumerate(classes)}
14
  num_classes = len(classes)
15
 
16
  # Load the model
17
- from keras.models import load_model
18
- model = load_model('sketch_recogination_model_cnn.h5')
 
 
 
19
 
20
  # Predict function for the interface
21
  def predict_fn(image):
@@ -29,8 +32,11 @@ def predict_fn(image):
29
  The predicted class name.
30
  """
31
  try:
 
 
 
32
  # Preprocessing the image
33
- resized_image = tf.image.resize(image, (28, 28)) # Resize to (28, 28)
34
  grayscale_image = tf.image.rgb_to_grayscale(resized_image) # Convert to grayscale
35
  image_array = np.array(grayscale_image) / 255.0 # Normalize the image
36
 
@@ -47,14 +53,12 @@ def predict_fn(image):
47
  return f"Error in prediction: {str(e)}"
48
 
49
  # Gradio application interface
50
-
51
  gr.Interface(
52
  fn=predict_fn,
53
  inputs="paint",
54
  outputs="label",
55
  title="DoodleDecoder",
56
  description="Draw something from: Car, House, Wine bottle, Chair, Table, Tree, Camera, Fish, Rain, Clock, Hat",
57
- interpretation='default',
58
  article="Draw large with thick stroke."
59
  ).launch()
60
-
 
14
  num_classes = len(classes)
15
 
16
  # Load the model
17
+ model_path = 'sketch_recognition_model_cnn.h5' # Ensure this path is correct
18
+ try:
19
+ model = load_model(model_path)
20
+ except Exception as e:
21
+ raise RuntimeError(f"Failed to load model from {model_path}: {e}")
22
 
23
  # Predict function for the interface
24
  def predict_fn(image):
 
32
  The predicted class name.
33
  """
34
  try:
35
+ # Extract the image data from the input dictionary
36
+ image_data = image['image'] if isinstance(image, dict) else image
37
+
38
  # Preprocessing the image
39
+ resized_image = tf.image.resize(image_data, (28, 28)) # Resize to (28, 28)
40
  grayscale_image = tf.image.rgb_to_grayscale(resized_image) # Convert to grayscale
41
  image_array = np.array(grayscale_image) / 255.0 # Normalize the image
42
 
 
53
  return f"Error in prediction: {str(e)}"
54
 
55
  # Gradio application interface
 
56
  gr.Interface(
57
  fn=predict_fn,
58
  inputs="paint",
59
  outputs="label",
60
  title="DoodleDecoder",
61
  description="Draw something from: Car, House, Wine bottle, Chair, Table, Tree, Camera, Fish, Rain, Clock, Hat",
62
+ interpretation='default', # Add the interpretation parameter here
63
  article="Draw large with thick stroke."
64
  ).launch()