# Crop_anomaly_id / app.py
# Source file by jays009 — last commit fcc2886 "Update app.py" (7.05 kB).
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import os
import logging
import requests
from io import BytesIO
import numpy as np
from scipy.spatial.distance import mahalanobis
# ---- Decision configuration -------------------------------------------------
num_classes = 3  # number of outputs of the classifier head
CONFIDENCE_THRESHOLD = 0.8  # minimum softmax probability (80%) for a "valid" result
# Upper bound on the distance to the nearest class centroid before an input is
# treated as out-of-distribution. Calibrate this using a validation set.
MAHALANOBIS_THRESHOLD = 100.0

# ---- Logging ----------------------------------------------------------------
# Configure the root logger at WARNING (basicConfig is a no-op if handlers are
# already installed). Note: INFO-level messages are therefore not emitted.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
def download_model():
    """Fetch the fine-tuned ResNet checkpoint from the Hugging Face Hub.

    Returns:
        Local filesystem path of the downloaded (or already cached) ``model.pth``
        from the ``jays009/Resnet3`` repository.
    """
    return hf_hub_download(filename="model.pth", repo_id="jays009/Resnet3")
def load_main_model(model_path):
    """Build a ResNet-50 classifier and load fine-tuned weights from *model_path*.

    The torchvision ResNet-50 head is replaced by ``Dropout(0.5) -> Linear(in_features,
    num_classes)``. Checkpoints saved with a plain ``fc`` layer (keys ``fc.weight`` /
    ``fc.bias``) are remapped onto the ``Linear`` at index 1 of that Sequential.

    Args:
        model_path: Path to a checkpoint produced by ``torch.save``. Either a dict with a
            ``"model_state_dict"`` entry or a bare state dict is accepted.

    Returns:
        The model in eval mode on CPU.

    Raises:
        Whatever ``torch.load`` raises for a missing or unreadable file.
    """
    # NOTE(review): `pretrained` is deprecated in newer torchvision (use `weights=None`);
    # kept for compatibility with older torchvision releases.
    model = models.resnet50(pretrained=False)
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(in_features, num_classes),
    )

    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

    # Accept both the {"model_state_dict": ...} wrapper and a bare state dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # Old checkpoints have a plain Linear head; ours wraps it at Sequential index 1.
    remapped = {
        (f"fc.1.{key.split('.')[-1]}" if key in ("fc.weight", "fc.bias") else key): value
        for key, value in state_dict.items()
    }

    # strict=False tolerates extra entries, but silently ignoring mismatches could
    # leave layers (e.g. the classifier head) randomly initialised -- report them.
    result = model.load_state_dict(remapped, strict=False)
    if result.missing_keys:
        logger.warning(
            "Missing keys when loading %s (left at initial values): %s",
            model_path,
            result.missing_keys,
        )
    if result.unexpected_keys:
        logger.warning(
            "Unexpected keys ignored when loading %s: %s",
            model_path,
            result.unexpected_keys,
        )

    model.eval()
    return model
# Load per-class feature statistics (each entry holds "mean" and "cov") used for
# Mahalanobis OOD detection. Resolve the path relative to this file rather than
# the current working directory, so the app works when launched from elsewhere.
_CLASS_STATS_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "class_statistics.pth"
)
try:
    class_statistics = torch.load(_CLASS_STATS_PATH, map_location=torch.device("cpu"))
except FileNotFoundError:
    logger.error(
        "class_statistics.pth not found at %s. Please ensure the file is in the same "
        "directory as app.py.",
        _CLASS_STATS_PATH,
    )
    raise
# Download the checkpoint once at import time and build the shared classifier
# that every request reuses.
model_path = download_model()
main_model = load_main_model(model_path=model_path)
# Standard ImageNet per-channel mean / std values used for input normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Input preprocessing: resize the shorter side to 256 px, take the central
# 224x224 crop, convert to a float tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Memo of regularised inverse covariances keyed by (id(cov), eps). Each entry also
# stores the cov object itself, so its id cannot be recycled while cached. The
# cache assumes covariance objects are not mutated in place after first use.
_INV_COV_CACHE = {}
_INV_COV_CACHE_MAX = 32


def _as_numpy(value):
    """Convert a torch tensor (any device, with or without grad) or array-like to NumPy."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _regularized_inverse(cov, eps=1e-6):
    """Return inv(cov + eps * I) as float64, memoised per covariance object."""
    key = (id(cov), eps)
    cached = _INV_COV_CACHE.get(key)
    if cached is not None and cached[0] is cov:
        return cached[1]
    cov_np = _as_numpy(cov).astype(np.float64)
    # The small diagonal term keeps a near-singular covariance invertible.
    cov_inv = np.linalg.inv(cov_np + np.eye(cov_np.shape[0]) * eps)
    if len(_INV_COV_CACHE) >= _INV_COV_CACHE_MAX:
        _INV_COV_CACHE.clear()  # keep the memo bounded
    _INV_COV_CACHE[key] = (cov, cov_inv)
    return cov_inv


def compute_mahalanobis_distance(features, mean, cov):
    """Return the Mahalanobis distance between *features* and a class distribution.

    The inverse covariance is computed once per covariance object and reused. This
    avoids re-inverting a large matrix on every request.

    Args:
        features: 1-D feature vector (NumPy array or torch tensor).
        mean: Class mean vector (torch tensor or array-like), same length as *features*.
        cov: Class covariance matrix (torch tensor or array-like), square.

    Returns:
        The distance as a float, computed with inv(cov + 1e-6 * I).
    """
    return mahalanobis(_as_numpy(features), _as_numpy(mean), _regularized_inverse(cov))
def is_in_distribution(features, statistics=None, threshold=None):
    """Return True if *features* lie close enough to at least one known class.

    Args:
        features: 1-D feature vector from the model's penultimate layer.
        statistics: Mapping ``label -> {"mean": ..., "cov": ...}``. Defaults to the
            module-level ``class_statistics``.
        threshold: Maximum allowed minimum distance. Defaults to
            ``MAHALANOBIS_THRESHOLD``. Passing it explicitly is useful when calibrating.

    Returns:
        True when the smallest Mahalanobis distance over all classes is strictly
        below the threshold.

    Raises:
        ValueError: If the statistics mapping is empty.
    """
    stats = class_statistics if statistics is None else statistics
    limit = MAHALANOBIS_THRESHOLD if threshold is None else threshold
    if not stats:
        raise ValueError("No class statistics available for Mahalanobis OOD detection.")

    min_distance = min(
        compute_mahalanobis_distance(features, entry["mean"], entry["cov"])
        for entry in stats.values()
    )
    # Kept for threshold calibration. Emitted at INFO, so it is only visible when
    # the logging level is lowered below WARNING.
    logger.info("Minimum Mahalanobis distance: %.4f", min_distance)
    return min_distance < limit
# Class index -> metadata returned to callers; order must match the model's output head.
_CLASS_INFO = {
    0: {"name": "Fall Army Worm", "problem_id": "126", "crop": "maize"},
    1: {"name": "Phosphorus Deficiency", "problem_id": "142", "crop": "maize"},
    2: {"name": "Bacterial Leaf Blight", "problem_id": "203", "crop": "rice"},
}

# (connect, read) timeout in seconds for fetching the image, so a slow or
# unresponsive host cannot block the request handler indefinitely.
_IMAGE_REQUEST_TIMEOUT = (5, 30)


def _invalid_result(predicted_class=None, problem_id=None, confidence=None):
    """Build the response dict used for every rejected or failed prediction."""
    return {
        "status": "invalid",
        "predicted_class": predicted_class,
        "problem_id": problem_id,
        "confidence": confidence,
    }


def _fetch_image(image_url):
    """Download *image_url* and decode it as an RGB PIL image."""
    # NOTE(review): the URL is user-supplied, so this fetches arbitrary hosts with no
    # size limit -- consider an allow-list / size cap for a public deployment.
    response = requests.get(image_url, timeout=_IMAGE_REQUEST_TIMEOUT)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def _extract_features(model, image_tensor):
    """Run a torchvision ResNet up to global pooling; return flattened penultimate features."""
    # Mirrors torchvision's ResNet forward minus the final `fc`. This avoids swapping
    # `model.fc` on the shared model, which was not exception- or thread-safe.
    x = model.conv1(image_tensor)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)
    x = model.avgpool(x)
    return torch.flatten(x, 1)  # Shape: [N, 2048] for ResNet-50


def predict_from_image_url(image_url):
    """Classify the crop image at *image_url*.

    Processing steps:
    1. Fetch the image.
    2. Run the backbone once.
    3. Reject out-of-distribution inputs via Mahalanobis distance.
    4. Classify and apply ``CONFIDENCE_THRESHOLD``.

    Args:
        image_url: HTTP(S) URL of the image.

    Returns:
        A dict with keys ``status`` ("valid"/"invalid"), ``predicted_class``,
        ``problem_id`` and ``confidence`` (a percentage string such as "93.12%").
        ``status`` is "invalid" in three cases:
        - OOD inputs and errors, with the other fields None.
        - Unexpected class indices, with the other fields None.
        - Low-confidence predictions, with the class info filled in.
        Exceptions are logged and never propagated.
    """
    try:
        image = _fetch_image(image_url)
        image_tensor = transform(image).unsqueeze(0)  # Shape: [1, 3, 224, 224]

        with torch.no_grad():
            features = _extract_features(main_model, image_tensor)
            # In eval mode Dropout is a no-op, so fc(features) equals a full
            # forward pass -- one backbone pass instead of two.
            outputs = main_model.fc(features)  # Shape: [1, num_classes]

        # Stage 1: OOD detection using Mahalanobis distance.
        if not is_in_distribution(features[0].cpu().numpy()):
            logger.warning("Image URL %s detected as out-of-distribution.", image_url)
            return _invalid_result()

        # Stage 2: main model prediction.
        probabilities = torch.softmax(outputs, dim=1)[0]
        predicted_class = torch.argmax(outputs, dim=1).item()

        predicted_info = _CLASS_INFO.get(predicted_class)
        if predicted_info is None:
            logger.warning(
                "Unexpected class prediction: %s for image URL: %s", predicted_class, image_url
            )
            return _invalid_result()

        predicted_name = predicted_info["name"]
        problem_id = predicted_info["problem_id"]
        confidence = probabilities[predicted_class].item()
        confidence_str = f"{confidence*100:.2f}%"

        if confidence < CONFIDENCE_THRESHOLD:
            logger.warning(
                "Low confidence prediction: %s with confidence %s for image URL: %s",
                predicted_name,
                confidence_str,
                image_url,
            )
            return _invalid_result(predicted_name, problem_id, confidence_str)

        return {
            "status": "valid",
            "predicted_class": predicted_name,
            "problem_id": problem_id,
            "confidence": confidence_str,
        }
    except Exception as e:
        # Top-level request boundary: log with traceback and return a uniform error.
        logger.exception("Error processing image URL %s: %s", image_url, e)
        return _invalid_result()
# Gradio UI: a single text box for the image URL, with the prediction dict rendered as JSON.
demo = gr.Interface(
    title="Crop Anomaly Classification",
    description="Enter a URL to an image for classification (Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight).",
    fn=predict_from_image_url,
    inputs="text",
    outputs="json",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()