dielz commited on
Commit
1c2a810
·
verified ·
1 Parent(s): 22a6775
Files changed (1) hide show
  1. app.py +22 -11
app.py CHANGED
@@ -16,22 +16,33 @@ def predict_image(image):
16
 
17
  predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
18
 
19
- # Top 3 classes
20
- top_3_indices = np.argsort(predictions.numpy(), axis=1)[0][-3:][::-1]
21
- top_3_labels = [labels[i] for i in top_3_indices]
22
- top_3_probabilities = [predictions.numpy()[0][i] * 100 for i in top_3_indices]
23
-
24
- output_string = "\n".join([f"{label}: {probability:.2f}%" for label, probability in zip(top_3_labels, top_3_probabilities)])
25
-
26
- return image_resized, output_string
 
 
 
 
 
 
 
 
27
 
28
  # Gradio Interface
29
  interface = gr.Interface(
30
  fn=predict_image,
31
- inputs=gr.Image(type="pil"),
32
- outputs=[gr.Image(type="pil", label="Image Output"), gr.Textbox(label="Prediction")],
 
 
33
  title="Animals Classifier",
34
- description="Upload an image of an animal, and the model will predict it.\n\n**Disclaimer:** This model is trained only on specific animal classes (butterfly, cats, cow, dogs, elephant, horse, monkey, sheep, spider, squirrel) and may not accurately predict animals outside these classes."
 
35
  )
36
 
37
  interface.launch(share=True)
 
16
 
17
  predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
18
 
19
+ # Highest prediction
20
+ top_index = np.argmax(predictions.numpy(), axis=1)[0]
21
+ top_label = labels[top_index]
22
+ top_probability = predictions.numpy()[0][top_index] * 100
23
+
24
+ return top_label, top_probability
25
+
26
+ # Example images
27
+ example_images = [
28
+ ["exp_img/cat.jpg"],
29
+ ["exp_img/cow.jpg"],
30
+ ["exp_img/elephant.jpg"],
31
+ ["exp_img/sheep.jpg"],
32
+ ["exp_img/spider.jpg"],
33
+ ["exp_img/squirrel.jpg"]
34
+ ]
35
 
36
  # Gradio Interface
37
  interface = gr.Interface(
38
  fn=predict_image,
39
+ inputs=gr.Image(type="pil", shape=(224, 224)),
40
+ outputs=gr.Label(num_top_classes=1, label="Prediction"),
41
+ examples=example_images,
42
+ example_per_page=6,
43
  title="Animals Classifier",
44
+ description="Upload an image of an animal, and the model will predict it.\n\n**Disclaimer:** This model is trained only on specific animal classes (butterfly, cats, cow, dogs, elephant, horse, monkey, sheep, spider, squirrel) and may not accurately predict animals outside these classes.",
45
+ allow_flagging="never"
46
  )
47
 
48
  interface.launch(share=True)