jays009 commited on
Commit
05f28db
·
verified ·
1 Parent(s): 5aef664

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -3
app.py CHANGED
@@ -20,6 +20,9 @@ num_classes = 3
20
  # Confidence threshold for main model predictions
21
  CONFIDENCE_THRESHOLD = 0.8 # 80%
22
 
 
 
 
23
  # Download model from Hugging Face
24
  def download_model():
25
  model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
@@ -62,6 +65,13 @@ transform = transforms.Compose([
62
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
63
  ])
64
 
 
 
 
 
 
 
 
65
  # Prediction function for an uploaded image
66
  def predict_from_image_url(image_url):
67
  try:
@@ -78,12 +88,22 @@ def predict_from_image_url(image_url):
78
  with torch.no_grad():
79
  outputs = main_model(image_tensor) # Shape: [1, 3]
80
  logger.info(f"Model output shape: {outputs.shape}")
81
- # Log raw logits
82
  logger.info(f"Raw logits: {outputs[0].numpy()}")
83
  probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
84
- # Log softmax probabilities
85
  logger.info(f"Softmax probabilities: {probabilities.numpy()}")
86
- predicted_class = torch.argmax(outputs, dim=1).item()
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  # Define class information
89
  class_info = {
 
20
  # Confidence threshold for main model predictions
21
  CONFIDENCE_THRESHOLD = 0.8 # 80%
22
 
23
+ # Threshold for OOD detection (maximum softmax probability)
24
+ OPENMAX_THRESHOLD = 0.9 # If max softmax prob < 0.9, consider it OOD
25
+
26
  # Download model from Hugging Face
27
  def download_model():
28
  model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
 
65
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
66
  ])
67
 
68
+ # Simplified OpenMax-like OOD detection
69
+ def openmax_ood_detection(probabilities):
70
+ max_prob = torch.max(probabilities).item()
71
+ logger.info(f"Max softmax probability: {max_prob:.4f}")
72
+ # If the maximum probability is below the threshold, consider it OOD
73
+ return max_prob >= OPENMAX_THRESHOLD
74
+
75
  # Prediction function for an uploaded image
76
  def predict_from_image_url(image_url):
77
  try:
 
88
  with torch.no_grad():
89
  outputs = main_model(image_tensor) # Shape: [1, 3]
90
  logger.info(f"Model output shape: {outputs.shape}")
 
91
  logger.info(f"Raw logits: {outputs[0].numpy()}")
92
  probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
 
93
  logger.info(f"Softmax probabilities: {probabilities.numpy()}")
94
+
95
+ # OpenMax-like OOD detection
96
+ if not openmax_ood_detection(probabilities):
97
+ logger.warning(f"Image URL {image_url} detected as out-of-distribution.")
98
+ return {
99
+ "status": "invalid",
100
+ "predicted_class": None,
101
+ "problem_id": None,
102
+ "confidence": None
103
+ }
104
+
105
+ # Proceed with prediction
106
+ predicted_class = torch.argmax(outputs, dim=1).item()
107
 
108
  # Define class information
109
  class_info = {