pytholic commited on
Commit
31104c0
·
1 Parent(s): d2c0b63

app.py fixed

Browse files
Files changed (1) hide show
  1. app/app.py +4 -22
app/app.py CHANGED
@@ -11,9 +11,8 @@ import matplotlib.pyplot as plt
11
  import numpy as np
12
  import torch
13
  from albumentations.pytorch import ToTensorV2
14
- from PIL import Image
15
-
16
  from model import Classifier
 
17
 
18
  # Load the model
19
  model = Classifier.load_from_checkpoint("./models/checkpoint.ckpt")
@@ -44,15 +43,6 @@ def preprocess(image):
44
  return image
45
 
46
 
47
- # Define the sample images
48
- sample_images = {
49
- "dog": "./test_images/dog.jpeg",
50
- "cat": "./test_images/cat.jpeg",
51
- "butterfly": "./test_images/butterfly.jpeg",
52
- "elephant": "./test_images/elephant.jpg",
53
- "horse": "./test_images/horse.jpeg",
54
- }
55
-
56
  # Define the function to make predictions on an image
57
  def predict(image):
58
  try:
@@ -65,16 +55,8 @@ def predict(image):
65
  # convert to probabilities
66
  probabilities = torch.nn.functional.softmax(output[0])
67
 
68
- topk_prob, topk_label = torch.topk(probabilities, 3)
69
-
70
- # convert the predictions to a list
71
- predictions = []
72
- for i in range(topk_prob.size(0)):
73
- prob = topk_prob[i].item()
74
- label = topk_label[i].item()
75
- predictions.append((prob, label))
76
-
77
- return predictions
78
  except Exception as e:
79
  print(f"Error predicting image: {e}")
80
  return []
@@ -93,7 +75,7 @@ def app():
93
  outputs=gr.Label(
94
  num_top_classes=3,
95
  ),
96
- examples=[
97
  "./test_images/dog.jpeg",
98
  "./test_images/cat.jpeg",
99
  "./test_images/butterfly.jpeg",
 
11
  import numpy as np
12
  import torch
13
  from albumentations.pytorch import ToTensorV2
 
 
14
  from model import Classifier
15
+ from PIL import Image
16
 
17
  # Load the model
18
  model = Classifier.load_from_checkpoint("./models/checkpoint.ckpt")
 
43
  return image
44
 
45
 
 
 
 
 
 
 
 
 
 
46
  # Define the function to make predictions on an image
47
  def predict(image):
48
  try:
 
55
  # convert to probabilities
56
  probabilities = torch.nn.functional.softmax(output[0])
57
 
58
+ # Return the top 3 predictions
59
+ return {labels[i]: float(probabilities[i]) for i in range(3)}
 
 
 
 
 
 
 
 
60
  except Exception as e:
61
  print(f"Error predicting image: {e}")
62
  return []
 
75
  outputs=gr.Label(
76
  num_top_classes=3,
77
  ),
78
+ examples=examples=[
79
  "./test_images/dog.jpeg",
80
  "./test_images/cat.jpeg",
81
  "./test_images/butterfly.jpeg",