Apex-X commited on
Commit
3b9f116
·
verified ·
1 Parent(s): 5223af2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -52
app.py CHANGED
@@ -4,62 +4,57 @@ import tensorflow as tf
4
  from keras.models import load_model
5
  import gradio as gr
6
 
7
- # classes:
8
  classes = [
9
- 'car',
10
- 'house',
11
- 'wine bottle',
12
- 'chair',
13
- 'table',
14
- 'tree',
15
- 'camera',
16
- 'fish',
17
- 'rain',
18
- 'clock',
19
- 'hat'
20
  ]
21
 
22
- # labels :
23
- labels = {
24
- 'car': 0,
25
- 'house': 1,
26
- 'wine bottle': 2,
27
- 'chair': 3,
28
- 'table': 4,
29
- 'tree': 5,
30
- 'camera': 6,
31
- 'fish': 7,
32
- 'rain': 8,
33
- 'clock': 9,
34
- 'hat': 10
35
- }
36
-
37
  num_classes = len(classes)
38
 
39
- # load the model:
40
- from keras.models import load_model
41
- model = load_model('sketch_recogination_model_cnn.h5')
 
 
 
42
 
43
- # Predict function for interface:
44
  def predict_fn(image):
45
-
46
- # preprocessing the size:
47
- resized_image = tf.image.resize(image, (28, 28)) # Resize image to (28, 28)
48
- grayscale_image = tf.image.rgb_to_grayscale(resized_image) # Convert image to grayscale
49
-
50
- image = np.array(grayscale_image)
51
-
52
- # model requirements:
53
- image = image.reshape(1,28,28,1)
54
- label = tf.constant(model.predict(image).reshape(num_classes)) # giving 2D output so 1D
55
-
56
- # predict:
57
- predicted_index = tf.argmax(label)
58
- class_name = [name for name, index in labels.items() if predicted_index == index][0]
59
- return class_name
60
-
61
-
62
- # application interface:
63
- import gradio as gr
64
- gr.Interface(fn=predict_fn, inputs="paint", outputs="label", title="DoodleDecoder", description="Draw something from: Car, House, Wine bottle, Chair, Table, Tree, Camera, Fish, Rain, Clock, Hat", interpretation='default', article="Draw large with thick stroke.").launch()
65
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from keras.models import load_model
5
  import gradio as gr
6
 
7
+ # Define the classes and labels
8
  classes = [
9
+ 'car', 'house', 'wine bottle', 'chair', 'table',
10
+ 'tree', 'camera', 'fish', 'rain', 'clock', 'hat'
 
 
 
 
 
 
 
 
 
11
  ]
12
 
13
+ labels = {name: index for index, name in enumerate(classes)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  num_classes = len(classes)
15
 
16
+ # Load the model
17
+ model_path = 'sketch_recognition_model_cnn.h5' # Make sure this path is correct
18
+ try:
19
+ model = load_model(model_path)
20
+ except Exception as e:
21
+ raise RuntimeError(f"Failed to load model from {model_path}: {e}")
22
 
23
+ # Predict function for the interface
24
  def predict_fn(image):
25
+ """
26
+ Predict the class of a drawn image.
27
+
28
+ Args:
29
+ image: The input image drawn by the user.
30
+
31
+ Returns:
32
+ The predicted class name.
33
+ """
34
+ try:
35
+ # Preprocessing the image
36
+ resized_image = tf.image.resize(image, (28, 28)) # Resize to (28, 28)
37
+ grayscale_image = tf.image.rgb_to_grayscale(resized_image) # Convert to grayscale
38
+ image_array = np.array(grayscale_image) / 255.0 # Normalize the image
39
+
40
+ # Prepare image for model input
41
+ image_array = image_array.reshape(1, 28, 28, 1) # Add batch dimension
42
+ predictions = model.predict(image_array).reshape(num_classes) # Reshape to 1D
43
+
44
+ # Get the predicted class index
45
+ predicted_index = tf.argmax(predictions).numpy() # Convert to numpy
46
+ class_name = classes[predicted_index] # Get class name
47
+
48
+ return class_name
49
+ except Exception as e:
50
+ return f"Error in prediction: {str(e)}"
51
+
52
+ # Gradio application interface
53
+ gr.Interface(
54
+ fn=predict_fn,
55
+ inputs="paint",
56
+ outputs="label",
57
+ title="DoodleDecoder",
58
+ description="Draw something from: Car, House, Wine bottle, Chair, Table, Tree, Camera, Fish, Rain, Clock, Hat",
59
+ article="Draw large with thick stroke."
60
+ ).launch()