Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -20,6 +20,9 @@ num_classes = 3
|
|
20 |
# Confidence threshold for main model predictions
|
21 |
CONFIDENCE_THRESHOLD = 0.8 # 80%
|
22 |
|
|
|
|
|
|
|
23 |
# Download model from Hugging Face
|
24 |
def download_model():
|
25 |
model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
|
@@ -62,6 +65,13 @@ transform = transforms.Compose([
|
|
62 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
63 |
])
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
# Prediction function for an uploaded image
|
66 |
def predict_from_image_url(image_url):
|
67 |
try:
|
@@ -78,12 +88,22 @@ def predict_from_image_url(image_url):
|
|
78 |
with torch.no_grad():
|
79 |
outputs = main_model(image_tensor) # Shape: [1, 3]
|
80 |
logger.info(f"Model output shape: {outputs.shape}")
|
81 |
-
# Log raw logits
|
82 |
logger.info(f"Raw logits: {outputs[0].numpy()}")
|
83 |
probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
|
84 |
-
# Log softmax probabilities
|
85 |
logger.info(f"Softmax probabilities: {probabilities.numpy()}")
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
# Define class information
|
89 |
class_info = {
|
|
|
20 |
# Confidence threshold for main model predictions
|
21 |
CONFIDENCE_THRESHOLD = 0.8 # 80%
|
22 |
|
23 |
+
# Threshold for OOD detection (maximum softmax probability)
|
24 |
+
OPENMAX_THRESHOLD = 0.9 # If max softmax prob < 0.9, consider it OOD
|
25 |
+
|
26 |
# Download model from Hugging Face
|
27 |
def download_model():
|
28 |
model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
|
|
|
65 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
66 |
])
|
67 |
|
68 |
+
# Simplified OpenMax-like OOD detection
|
69 |
+
def openmax_ood_detection(probabilities):
|
70 |
+
max_prob = torch.max(probabilities).item()
|
71 |
+
logger.info(f"Max softmax probability: {max_prob:.4f}")
|
72 |
+
# If the maximum probability is below the threshold, consider it OOD
|
73 |
+
return max_prob >= OPENMAX_THRESHOLD
|
74 |
+
|
75 |
# Prediction function for an uploaded image
|
76 |
def predict_from_image_url(image_url):
|
77 |
try:
|
|
|
88 |
with torch.no_grad():
|
89 |
outputs = main_model(image_tensor) # Shape: [1, 3]
|
90 |
logger.info(f"Model output shape: {outputs.shape}")
|
|
|
91 |
logger.info(f"Raw logits: {outputs[0].numpy()}")
|
92 |
probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
|
|
|
93 |
logger.info(f"Softmax probabilities: {probabilities.numpy()}")
|
94 |
+
|
95 |
+
# OpenMax-like OOD detection
|
96 |
+
if not openmax_ood_detection(probabilities):
|
97 |
+
logger.warning(f"Image URL {image_url} detected as out-of-distribution.")
|
98 |
+
return {
|
99 |
+
"status": "invalid",
|
100 |
+
"predicted_class": None,
|
101 |
+
"problem_id": None,
|
102 |
+
"confidence": None
|
103 |
+
}
|
104 |
+
|
105 |
+
# Proceed with prediction
|
106 |
+
predicted_class = torch.argmax(outputs, dim=1).item()
|
107 |
|
108 |
# Define class information
|
109 |
class_info = {
|