# app.py - Gradio demo for a Hugging Face Space: classifies crop anomalies from an image URL.
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import os
import logging
import requests
from io import BytesIO
import numpy as np
from scipy.spatial.distance import mahalanobis
# Logging setup. The default level stays at WARNING to keep the logs quiet.
# Set the LOG_LEVEL environment variable (e.g. LOG_LEVEL=INFO) to surface the
# "Minimum Mahalanobis distance" message that is_in_distribution logs at INFO
# level; with a fixed WARNING level that calibration log is never emitted.
_log_level_name = (os.environ.get("LOG_LEVEL") or "WARNING").upper()
_log_level = getattr(logging, _log_level_name, None)
# Unknown level names fall back to WARNING instead of crashing at import time
logging.basicConfig(level=_log_level if isinstance(_log_level, int) else logging.WARNING)
logger = logging.getLogger(__name__)

# Number of output classes of the classifier head
num_classes = 3

# Minimum softmax probability for a main-model prediction to be reported as valid
CONFIDENCE_THRESHOLD = 0.8  # 80%

# Mahalanobis distance threshold for OOD detection
MAHALANOBIS_THRESHOLD = 100.0  # Calibrate this using a validation set
# Download model from Hugging Face
def download_model(repo_id="jays009/Resnet3", filename="model.pth"):
    """Fetch a checkpoint file from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that holds the checkpoint. Defaults to the
            repository this app was written for.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        The value returned by ``hf_hub_download`` for that file, i.e. the
        local filesystem path of the downloaded checkpoint.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename)
# Load the main model from a checkpoint file
def load_main_model(model_path):
    """Build the ResNet-50 classifier and load its weights from ``model_path``.

    The network is a torchvision ResNet-50 whose final layer is replaced by
    ``Dropout(0.5)`` followed by a ``Linear`` layer with ``num_classes``
    outputs.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``. Either a
            dict with a ``"model_state_dict"`` entry or a bare state dict.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.
    """
    # Called without arguments this yields a randomly initialised network on
    # every torchvision version (no deprecated ``pretrained=`` keyword needed).
    model = models.resnet50()
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, num_classes),  # 3 classes
    )

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a repository you trust.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        # Tolerate checkpoints that are a bare state dict
        state_dict = checkpoint

    # Keys saved for a plain Linear head ("fc.weight"/"fc.bias") must be moved
    # to index 1 of the Sequential head built above.
    renamed = {"fc.weight": "fc.1.weight", "fc.bias": "fc.1.bias"}
    new_state_dict = {renamed.get(key, key): value for key, value in state_dict.items()}

    # strict=False keeps loading tolerant, but report any mismatch: silently
    # skipped keys would leave layers at their random initial values.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys:
        logger.warning("Checkpoint is missing keys: %s", load_result.missing_keys)
    if load_result.unexpected_keys:
        logger.warning("Checkpoint has unexpected keys: %s", load_result.unexpected_keys)

    model.eval()
    return model
# Load class statistics for Mahalanobis distance.
# The path is resolved relative to this file (not the working directory) so
# that the error message below is accurate wherever the app is started from.
# Presumably a mapping of class label -> {"mean": ..., "cov": ...}, which is
# how is_in_distribution reads it - confirm against the file that is shipped.
_STATS_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "class_statistics.pth")
try:
    # NOTE(security): torch.load unpickles the file; it must come from a trusted source.
    class_statistics = torch.load(_STATS_PATH, map_location=torch.device("cpu"))
except FileNotFoundError:
    logger.error("class_statistics.pth not found. Please ensure the file is in the same directory as app.py.")
    raise

# Download the checkpoint and build the classifier once, at import time
model_path = download_model()
main_model = load_main_model(model_path)
# Preprocessing pipeline applied to every input image before inference.
# The mean/std values are the per-channel statistics commonly used with
# ImageNet-trained torchvision models.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
)
# Compute Mahalanobis distance for OOD detection
def _to_numpy(value):
    """Return ``value`` as a NumPy array; accepts PyTorch tensors and array-likes."""
    if hasattr(value, "detach"):
        # Tensor: drop autograd tracking and move to host memory before converting
        value = value.detach().cpu().numpy()
    return np.asarray(value)


def compute_mahalanobis_distance(features, mean, cov):
    """Return the Mahalanobis distance between ``features`` and a class distribution.

    Args:
        features: 1-D feature vector (NumPy array or PyTorch tensor).
        mean: 1-D class mean with the same length (array or tensor).
        cov: Square class covariance matrix (array or tensor).

    Returns:
        The distance computed by ``scipy.spatial.distance.mahalanobis`` using
        the inverse of the regularised covariance.

    Raises:
        numpy.linalg.LinAlgError: If the regularised covariance is singular.
    """
    features_np = _to_numpy(features)
    mean_np = _to_numpy(mean)
    cov_np = _to_numpy(cov)
    # Add a small epsilon to the diagonal for numerical stability before inverting.
    # NOTE: the inverse is recomputed on every call; callers that evaluate many
    # samples against the same class may want to precompute it.
    cov_inv = np.linalg.inv(cov_np + np.eye(cov_np.shape[0]) * 1e-6)
    return mahalanobis(features_np, mean_np, cov_inv)
# OOD detection using Mahalanobis distance
def is_in_distribution(features):
    """Tell whether ``features`` lies close enough to any known class.

    The Mahalanobis distance to every entry of ``class_statistics`` is
    computed from that entry's "mean" and "cov"; the sample counts as
    in-distribution when the smallest distance is below
    ``MAHALANOBIS_THRESHOLD``.

    Args:
        features: Feature vector passed on to ``compute_mahalanobis_distance``.

    Returns:
        True when the minimum distance is strictly below the threshold.
    """
    min_distance = min(
        [
            compute_mahalanobis_distance(features, stats["mean"], stats["cov"])
            for stats in class_statistics.values()
        ]
    )
    logger.info(f"Minimum Mahalanobis distance: {min_distance:.4f}")  # Keep this log for calibration
    return min_distance < MAHALANOBIS_THRESHOLD
# Prediction function for an image referenced by URL

# Upper bound (seconds) on the image download so a stalled server cannot
# block a worker indefinitely.
REQUEST_TIMEOUT_SECONDS = 10

# Class index -> human-readable name, problem id and crop
_CLASS_INFO = {
    0: {"name": "Fall Army Worm", "problem_id": "126", "crop": "maize"},
    1: {"name": "Phosphorus Deficiency", "problem_id": "142", "crop": "maize"},
    2: {"name": "Bacterial Leaf Blight", "problem_id": "203", "crop": "rice"},
}


def _invalid_result(predicted_class=None, problem_id=None, confidence=None):
    """Build the response dict used for every rejected request."""
    return {
        "status": "invalid",
        "predicted_class": predicted_class,
        "problem_id": problem_id,
        "confidence": confidence,
    }


def _extract_features(model, image_tensor):
    """Run the ResNet backbone (everything before ``fc``) and return flat features.

    The layers are called directly instead of swapping ``model.fc`` for
    ``nn.Identity``: the shared model is never mutated, so neither a failing
    forward pass nor a concurrent request can observe it without its head.
    """
    x = model.conv1(image_tensor)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)
    x = model.avgpool(x)
    return torch.flatten(x, 1)


def predict_from_image_url(image_url):
    """Download the image at ``image_url`` and classify it.

    The image is first checked for being out-of-distribution using the
    Mahalanobis distance of its backbone features; only in-distribution
    images are classified, and predictions below ``CONFIDENCE_THRESHOLD``
    are reported as invalid.

    Args:
        image_url: URL of the image to classify.

    Returns:
        A dict with keys "status" ("valid" or "invalid"), "predicted_class",
        "problem_id" and "confidence" (a percentage string, or None). Any
        exception is logged and reported as an invalid result with the other
        fields set to None.
    """
    try:
        # NOTE(security): image_url is user-supplied and is fetched server-side as-is.
        response = requests.get(image_url, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")  # Convert to RGB (3 channels)

        # Apply transformations
        image_tensor = transform(image).unsqueeze(0)  # Shape: [1, 3, 224, 224]

        # Extract features from the penultimate layer
        with torch.no_grad():
            features = _extract_features(main_model, image_tensor)  # Shape: [1, 2048]

        # Stage 1: OOD Detection using Mahalanobis distance
        if not is_in_distribution(features[0].cpu().numpy()):
            logger.warning(f"Image URL {image_url} detected as out-of-distribution.")
            return _invalid_result()

        # Stage 2: Main Model Prediction. Only the head is applied to the
        # features already computed, so the backbone runs once per request.
        with torch.no_grad():
            outputs = main_model.fc(features)  # Shape: [1, 3]
        probabilities = torch.softmax(outputs, dim=1)[0]  # Convert to probabilities
        predicted_class = torch.argmax(outputs, dim=1).item()

        # Validate predicted class index
        if predicted_class not in _CLASS_INFO:
            logger.warning(f"Unexpected class prediction: {predicted_class} for image URL: {image_url}")
            return _invalid_result()

        # Get predicted class info
        predicted_info = _CLASS_INFO[predicted_class]
        predicted_name = predicted_info["name"]
        problem_id = predicted_info["problem_id"]
        confidence = probabilities[predicted_class].item()

        # Check confidence threshold
        if confidence < CONFIDENCE_THRESHOLD:
            logger.warning(
                f"Low confidence prediction: {predicted_name} with confidence {confidence*100:.2f}% "
                f"for image URL: {image_url}"
            )
            return _invalid_result(predicted_name, problem_id, f"{confidence*100:.2f}%")

        # Return successful prediction
        return {
            "status": "valid",
            "predicted_class": predicted_name,
            "problem_id": problem_id,
            "confidence": f"{confidence*100:.2f}%",
        }
    except Exception as e:
        # Boundary handler: the UI always gets a well-formed response
        logger.error(f"Error processing image URL {image_url}: {str(e)}")
        return _invalid_result()
# Gradio interface: a text box for the image URL in, the prediction dict out as JSON
demo = gr.Interface(
    fn=predict_from_image_url,
    inputs="text",
    outputs="json",
    title="Crop Anomaly Classification",
    description="Enter a URL to an image for classification (Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight).",
)

# Start the web server only when the file is run as a script
if __name__ == "__main__":
    demo.launch()