pytholic commited on
Commit
dbee91e
·
1 Parent(s): 9def6f5
Files changed (1) hide show
  1. app/app.py +8 -2
app/app.py CHANGED
@@ -11,9 +11,10 @@ import matplotlib.pyplot as plt
11
  import numpy as np
12
  import torch
13
  from albumentations.pytorch import ToTensorV2
14
- from model import Classifier
15
  from PIL import Image
16
 
 
 
17
  # Load the model
18
  model = Classifier.load_from_checkpoint("./models/checkpoint.ckpt")
19
  model.eval()
@@ -55,8 +56,13 @@ def predict(image):
55
  # convert to probabilities
56
  probabilities = torch.nn.functional.softmax(output[0])
57
 
 
 
 
58
  # Return the top 3 predictions
59
- return {labels[i]: float(probabilities[i]) for i in range(3)}
 
 
60
  except Exception as e:
61
  print(f"Error predicting image: {e}")
62
  return []
 
11
  import numpy as np
12
  import torch
13
  from albumentations.pytorch import ToTensorV2
 
14
  from PIL import Image
15
 
16
+ from model import Classifier
17
+
18
  # Load the model
19
  model = Classifier.load_from_checkpoint("./models/checkpoint.ckpt")
20
  model.eval()
 
56
  # convert to probabilities
57
  probabilities = torch.nn.functional.softmax(output[0])
58
 
59
+ # get top probabilities
60
+ topk_prob, topk_label = torch.topk(probabilities, 3)
61
+
62
  # Return the top 3 predictions
63
+ return {
64
+ labels[label]: float(prob) for label, prob in zip(topk_label, topk_prob)
65
+ }
66
  except Exception as e:
67
  print(f"Error predicting image: {e}")
68
  return []