jays009 commited on
Commit
ed3f92b
·
verified ·
1 Parent(s): 3cb0798

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -58
app.py CHANGED
@@ -4,15 +4,12 @@ import torch.nn as nn
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
- import os
8
  import logging
9
  import requests
10
  from io import BytesIO
11
- import numpy as np
12
- from scipy.spatial.distance import mahalanobis
13
 
14
  # Setup logging
15
- logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
  # Define the number of classes
@@ -21,8 +18,8 @@ num_classes = 3
21
  # Confidence threshold for main model predictions
22
  CONFIDENCE_THRESHOLD = 0.8 # 80%
23
 
24
- # Mahalanobis distance threshold for OOD detection
25
- MAHALANOBIS_THRESHOLD = 400.0 # Calibrate this using a validation set
26
 
27
  # Download model from Hugging Face
28
  def download_model():
@@ -54,13 +51,6 @@ def load_main_model(model_path):
54
  model.eval()
55
  return model
56
 
57
- # Load class statistics for Mahalanobis distance
58
- try:
59
- class_statistics = torch.load("class_statistics.pth", map_location=torch.device("cpu"))
60
- except FileNotFoundError:
61
- logger.error("class_statistics.pth not found. Please ensure the file is in the same directory as app.py.")
62
- raise
63
-
64
  # Path to your model
65
  model_path = download_model()
66
  main_model = load_main_model(model_path)
@@ -73,27 +63,15 @@ transform = transforms.Compose([
73
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
74
  ])
75
 
76
- # Compute Mahalanobis distance for OOD detection
77
- def compute_mahalanobis_distance(features, mean, cov):
78
- # Convert PyTorch tensors to NumPy arrays for scipy
79
- features_np = features
80
- mean_np = mean.cpu().numpy()
81
- cov_np = cov.cpu().numpy()
82
- # Compute the inverse covariance matrix
83
- cov_inv = np.linalg.inv(cov_np + np.eye(cov_np.shape[0]) * 1e-6) # Add small epsilon for numerical stability
84
- return mahalanobis(features_np, mean_np, cov_inv)
85
-
86
- # OOD detection using Mahalanobis distance
87
- def is_in_distribution(features):
88
- distances = []
89
- for label in class_statistics:
90
- mean = class_statistics[label]["mean"]
91
- cov = class_statistics[label]["cov"]
92
- distance = compute_mahalanobis_distance(features, mean, cov)
93
- distances.append(distance)
94
- min_distance = min(distances)
95
- logger.info(f"Minimum Mahalanobis distance: {min_distance:.4f}")
96
- return min_distance < MAHALANOBIS_THRESHOLD
97
 
98
  # Prediction function for an uploaded image
99
  def predict_from_image_url(image_url):
@@ -105,35 +83,23 @@ def predict_from_image_url(image_url):
105
 
106
  # Apply transformations
107
  image_tensor = transform(image).unsqueeze(0) # Shape: [1, 3, 224, 224]
108
- logger.info(f"Input image tensor shape: {image_tensor.shape}")
109
 
110
- # Extract features from the penultimate layer
111
  with torch.no_grad():
112
- # Temporarily replace the final layer to get features
113
- original_fc = main_model.fc
114
- main_model.fc = nn.Identity()
115
- features = main_model(image_tensor) # Shape: [1, 2048]
116
- main_model.fc = original_fc # Restore the final layer
117
- features = features[0].cpu().numpy() # Convert to numpy for scipy
118
-
119
- # Stage 1: OOD Detection using Mahalanobis distance
120
- if not is_in_distribution(features):
121
- logger.warning(f"Image URL {image_url} detected as out-of-distribution.")
122
- return {
123
- "status": "invalid",
124
- "predicted_class": None,
125
- "problem_id": None,
126
- "confidence": None
127
- }
128
 
129
  # Stage 2: Main Model Prediction
130
  with torch.no_grad():
131
- outputs = main_model(image_tensor) # Shape: [1, 3]
132
- logger.info(f"Model output shape: {outputs.shape}")
133
- logger.info(f"Raw logits: {outputs[0].numpy()}")
134
- probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
135
- logger.info(f"Softmax probabilities: {probabilities.numpy()}")
136
- predicted_class = torch.argmax(outputs, dim=1).item()
137
 
138
  # Define class information
139
  class_info = {
 
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
 
7
  import logging
8
  import requests
9
  from io import BytesIO
 
 
10
 
11
  # Setup logging
12
+ logging.basicConfig(level=logging.WARNING)
13
  logger = logging.getLogger(__name__)
14
 
15
  # Define the number of classes
 
18
  # Confidence threshold for main model predictions
19
  CONFIDENCE_THRESHOLD = 0.8 # 80%
20
 
21
+ # Energy threshold for OOD detection (to be calibrated)
22
+ ENERGY_THRESHOLD = -5.0 # Placeholder, will calibrate
23
 
24
  # Download model from Hugging Face
25
  def download_model():
 
51
  model.eval()
52
  return model
53
 
 
 
 
 
 
 
 
54
  # Path to your model
55
  model_path = download_model()
56
  main_model = load_main_model(model_path)
 
63
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
64
  ])
65
 
66
+ # Compute energy score for OOD detection
67
+ def compute_energy_score(logits, temperature=1.0):
68
+ return -temperature * torch.logsumexp(logits / temperature, dim=1).item()
69
+
70
+ # OOD detection using energy score
71
+ def is_in_distribution(logits):
72
+ energy = compute_energy_score(logits)
73
+ logger.info(f"Energy score: {energy:.4f}") # Log for calibration
74
+ return energy < ENERGY_THRESHOLD # Lower (more negative) energy means ID
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  # Prediction function for an uploaded image
77
  def predict_from_image_url(image_url):
 
83
 
84
  # Apply transformations
85
  image_tensor = transform(image).unsqueeze(0) # Shape: [1, 3, 224, 224]
 
86
 
87
+ # Stage 1: OOD Detection using energy score
88
  with torch.no_grad():
89
+ logits = main_model(image_tensor) # Shape: [1, 3]
90
+ if not is_in_distribution(logits):
91
+ logger.warning(f"Image URL {image_url} detected as out-of-distribution.")
92
+ return {
93
+ "status": "invalid",
94
+ "predicted_class": None,
95
+ "problem_id": None,
96
+ "confidence": None
97
+ }
 
 
 
 
 
 
 
98
 
99
  # Stage 2: Main Model Prediction
100
  with torch.no_grad():
101
+ probabilities = torch.softmax(logits, dim=1)[0] # Convert to probabilities
102
+ predicted_class = torch.argmax(logits, dim=1).item()
 
 
 
 
103
 
104
  # Define class information
105
  class_info = {