jays009 commited on
Commit
9487094
·
verified ·
1 Parent(s): ca52368

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -20
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from torch import nn
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
@@ -9,6 +9,7 @@ import logging
9
  import requests
10
  from io import BytesIO
11
  import numpy as np
 
12
 
13
  # Setup logging
14
  logging.basicConfig(level=logging.INFO)
@@ -20,8 +21,8 @@ num_classes = 3
20
  # Confidence threshold for main model predictions
21
  CONFIDENCE_THRESHOLD = 0.8 # 80%
22
 
23
- # Threshold for OOD detection (maximum softmax probability)
24
- OPENMAX_THRESHOLD = 0.9 # If max softmax prob < 0.9, consider it OOD
25
 
26
  # Download model from Hugging Face
27
  def download_model():
@@ -53,6 +54,13 @@ def load_main_model(model_path):
53
  model.eval()
54
  return model
55
 
 
 
 
 
 
 
 
56
  # Path to your model
57
  model_path = download_model()
58
  main_model = load_main_model(model_path)
@@ -65,12 +73,23 @@ transform = transforms.Compose([
65
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
66
  ])
67
 
68
- # Simplified OpenMax-like OOD detection
69
- def openmax_ood_detection(probabilities):
70
- max_prob = torch.max(probabilities).item()
71
- logger.info(f"Max softmax probability: {max_prob:.4f}")
72
- # If the maximum probability is below the threshold, consider it OOD
73
- return max_prob >= OPENMAX_THRESHOLD
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # Prediction function for an uploaded image
76
  def predict_from_image_url(image_url):
@@ -84,16 +103,17 @@ def predict_from_image_url(image_url):
84
  image_tensor = transform(image).unsqueeze(0) # Shape: [1, 3, 224, 224]
85
  logger.info(f"Input image tensor shape: {image_tensor.shape}")
86
 
87
- # Perform prediction
88
  with torch.no_grad():
89
- outputs = main_model(image_tensor) # Shape: [1, 3]
90
- logger.info(f"Model output shape: {outputs.shape}")
91
- logger.info(f"Raw logits: {outputs[0].numpy()}")
92
- probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
93
- logger.info(f"Softmax probabilities: {probabilities.numpy()}")
94
-
95
- # OpenMax-like OOD detection
96
- if not openmax_ood_detection(probabilities):
 
97
  logger.warning(f"Image URL {image_url} detected as out-of-distribution.")
98
  return {
99
  "status": "invalid",
@@ -102,8 +122,14 @@ def predict_from_image_url(image_url):
102
  "confidence": None
103
  }
104
 
105
- # Proceed with prediction
106
- predicted_class = torch.argmax(outputs, dim=1).item()
 
 
 
 
 
 
107
 
108
  # Define class information
109
  class_info = {
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn as nn
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
 
9
  import requests
10
  from io import BytesIO
11
  import numpy as np
12
+ from scipy.spatial.distance import mahalanobis
13
 
14
  # Setup logging
15
  logging.basicConfig(level=logging.INFO)
 
21
  # Confidence threshold for main model predictions
22
  CONFIDENCE_THRESHOLD = 0.8 # 80%
23
 
24
+ # Mahalanobis distance threshold for OOD detection
25
+ MAHALANOBIS_THRESHOLD = 100.0 # Calibrate this using a validation set
26
 
27
  # Download model from Hugging Face
28
  def download_model():
 
54
  model.eval()
55
  return model
56
 
57
+ # Load class statistics for Mahalanobis distance
58
+ try:
59
+ class_statistics = torch.load("class_statistics.pth", map_location=torch.device("cpu"))
60
+ except FileNotFoundError:
61
+ logger.error("class_statistics.pth not found. Please run the statistics computation script first.")
62
+ raise
63
+
64
  # Path to your model
65
  model_path = download_model()
66
  main_model = load_main_model(model_path)
 
73
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
74
  ])
75
 
76
+ # Compute Mahalanobis distance for OOD detection
77
+ def compute_mahalanobis_distance(features, mean, cov):
78
+ # Compute the inverse covariance matrix
79
+ cov_inv = np.linalg.inv(cov + np.eye(cov.shape[0]) * 1e-6) # Add small epsilon for numerical stability
80
+ return mahalanobis(features, mean, cov_inv)
81
+
82
+ # OOD detection using Mahalanobis distance
83
+ def is_in_distribution(features):
84
+ distances = []
85
+ for label in class_statistics:
86
+ mean = class_statistics[label]["mean"]
87
+ cov = class_statistics[label]["cov"]
88
+ distance = compute_mahalanobis_distance(features, mean, cov)
89
+ distances.append(distance)
90
+ min_distance = min(distances)
91
+ logger.info(f"Minimum Mahalanobis distance: {min_distance:.4f}")
92
+ return min_distance < MAHALANOBIS_THRESHOLD
93
 
94
  # Prediction function for an uploaded image
95
  def predict_from_image_url(image_url):
 
103
  image_tensor = transform(image).unsqueeze(0) # Shape: [1, 3, 224, 224]
104
  logger.info(f"Input image tensor shape: {image_tensor.shape}")
105
 
106
+ # Extract features from the penultimate layer
107
  with torch.no_grad():
108
+ # Temporarily replace the final layer to get features
109
+ original_fc = main_model.fc
110
+ main_model.fc = nn.Identity()
111
+ features = main_model(image_tensor) # Shape: [1, 2048]
112
+ main_model.fc = original_fc # Restore the final layer
113
+ features = features[0].numpy() # Convert to numpy
114
+
115
+ # Stage 1: OOD Detection using Mahalanobis distance
116
+ if not is_in_distribution(features):
117
  logger.warning(f"Image URL {image_url} detected as out-of-distribution.")
118
  return {
119
  "status": "invalid",
 
122
  "confidence": None
123
  }
124
 
125
+ # Stage 2: Main Model Prediction
126
+ with torch.no_grad():
127
+ outputs = main_model(image_tensor) # Shape: [1, 3]
128
+ logger.info(f"Model output shape: {outputs.shape}")
129
+ logger.info(f"Raw logits: {outputs[0].numpy()}")
130
+ probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
131
+ logger.info(f"Softmax probabilities: {probabilities.numpy()}")
132
+ predicted_class = torch.argmax(outputs, dim=1).item()
133
 
134
  # Define class information
135
  class_info = {