Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| import os | |
| import logging | |
| import requests | |
| from io import BytesIO | |
| import numpy as np | |
| from scipy.spatial.distance import mahalanobis | |
# Number of output classes produced by the classifier head.
num_classes: int = 3

# Predictions whose top softmax probability is below this value are reported
# with status "invalid".
CONFIDENCE_THRESHOLD: float = 0.8  # i.e. 80%

# An image is treated as out-of-distribution when its smallest Mahalanobis
# distance to any class is at or above this value. Calibrate on a validation set.
MAHALANOBIS_THRESHOLD: float = 100.0

# Only WARNING and above are emitted; INFO messages (e.g. distance logs used
# for calibration) are suppressed at this level.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
# Download model from Hugging Face
def download_model(repo_id="jays009/Resnet3", filename="model.pth"):
    """Fetch the classifier checkpoint from the Hugging Face Hub.

    Args:
        repo_id: Hub repository hosting the checkpoint. Defaults to the
            repository this app was built against.
        filename: Name of the checkpoint file inside ``repo_id``.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename)
# Load the main model from Hugging Face
def load_main_model(model_path):
    """Build the ResNet-50 classifier and load its weights from ``model_path``.

    The classification head is ``Sequential(Dropout(0.5), Linear(in_features,
    num_classes))``. Checkpoint keys ``fc.weight`` / ``fc.bias`` (presumably
    saved from a model with a plain ``Linear`` head) are renamed to
    ``fc.1.weight`` / ``fc.1.bias`` so they land on the ``Linear`` layer.

    Args:
        model_path: Path to a file written with ``torch.save``: either a dict
            holding the weights under ``"model_state_dict"`` or a bare state dict.

    Returns:
        The model on CPU, switched to eval mode.
    """
    model = models.resnet50(pretrained=False)
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, num_classes),
    )

    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from a source you trust.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        # Accept a bare state dict as well as the training-checkpoint wrapper.
        state_dict = checkpoint

    head_keys = {"fc.weight", "fc.bias"}
    new_state_dict = {
        (f"fc.1.{key.split('.')[-1]}" if key in head_keys else key): value
        for key, value in state_dict.items()
    }

    # strict=False tolerates mismatches, but a key that fails to load keeps its
    # random initialisation (fatal for the head), so surface any mismatch.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys:
        logger.warning(
            "Checkpoint is missing %d key(s); they keep random init: %s",
            len(load_result.missing_keys),
            load_result.missing_keys,
        )
    if load_result.unexpected_keys:
        logger.warning(
            "Checkpoint has %d unexpected key(s) that were ignored: %s",
            len(load_result.unexpected_keys),
            load_result.unexpected_keys,
        )

    model.eval()
    return model
# Load class statistics for Mahalanobis distance.
# NOTE(review): torch.load unpickles the file; only load statistics you trust.
_APP_DIR = (
    os.path.dirname(os.path.abspath(__file__))
    if "__file__" in globals()
    else os.getcwd()
)


def _find_class_statistics(filename="class_statistics.pth"):
    """Return the path to ``filename``, trying the CWD first, then the app's directory.

    If neither location has the file, ``filename`` is returned unchanged so
    that ``torch.load`` raises ``FileNotFoundError`` exactly as before.
    """
    for candidate in (filename, os.path.join(_APP_DIR, filename)):
        if os.path.exists(candidate):
            return candidate
    return filename


try:
    class_statistics = torch.load(
        _find_class_statistics(), map_location=torch.device("cpu")
    )
except FileNotFoundError:
    logger.error("class_statistics.pth not found. Please ensure the file is in the same directory as app.py.")
    raise

# Fetch the checkpoint and build the classifier once, at import time.
model_path = download_model()
main_model = load_main_model(model_path)
# Standard ImageNet per-channel mean / standard deviation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Input preprocessing: resize the shorter side to 256 px, take the central
# 224x224 crop, convert to a float tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Cache of regularised inverse covariance matrices, keyed by (id(cov), eps).
# Each entry also holds a reference to ``cov`` itself, so the id cannot be
# recycled by a different object while the entry exists.
_COV_INV_CACHE = {}
_COV_INV_CACHE_MAX_ENTRIES = 32


def _to_numpy(array_like):
    """Convert a torch tensor (any device, with or without grad) or array-like to NumPy."""
    if isinstance(array_like, torch.Tensor):
        return array_like.detach().cpu().numpy()
    return np.asarray(array_like)


# Compute Mahalanobis distance for OOD detection
def compute_mahalanobis_distance(features, mean, cov, eps=1e-6):
    """Return the Mahalanobis distance between ``features`` and a class distribution.

    The covariance is regularised as ``cov + eps * I`` before inversion. The
    inverse is cached per ``cov`` object, so repeated calls with the same
    statistics skip the cubic-cost inversion. Do not mutate ``cov`` in place
    after the first call, or the cached inverse becomes stale.

    Args:
        features: 1-D feature vector (NumPy array, torch tensor, or array-like).
        mean: Class mean with the same length as ``features``.
        cov: Square class covariance matrix matching ``features``.
        eps: Ridge added to the covariance diagonal for numerical stability.

    Returns:
        The distance computed by ``scipy.spatial.distance.mahalanobis``.
    """
    cache_key = (id(cov), eps)
    cached = _COV_INV_CACHE.get(cache_key)
    if cached is not None and cached[0] is cov:
        cov_inv = cached[1]
    else:
        cov_np = _to_numpy(cov)
        cov_inv = np.linalg.inv(cov_np + np.eye(cov_np.shape[0]) * eps)
        if len(_COV_INV_CACHE) >= _COV_INV_CACHE_MAX_ENTRIES:
            # Keep memory bounded if callers pass many short-lived matrices.
            _COV_INV_CACHE.clear()
        _COV_INV_CACHE[cache_key] = (cov, cov_inv)
    return mahalanobis(_to_numpy(features), _to_numpy(mean), cov_inv)
# OOD detection using Mahalanobis distance
def is_in_distribution(features):
    """Return True when ``features`` are close enough to at least one known class.

    Computes the Mahalanobis distance from ``features`` to every class in
    ``class_statistics`` (each entry holding ``"mean"`` and ``"cov"``) and
    compares the smallest one against ``MAHALANOBIS_THRESHOLD``.

    Args:
        features: 1-D feature vector; presumably the penultimate-layer
            activations produced in ``predict_from_image_url`` (confirm).

    Returns:
        True if the minimum distance is strictly below the threshold.

    Raises:
        ValueError: if ``class_statistics`` contains no classes.
    """
    if not class_statistics:
        raise ValueError("class_statistics is empty; cannot perform OOD detection.")

    min_distance = min(
        compute_mahalanobis_distance(features, stats["mean"], stats["cov"])
        for stats in class_statistics.values()
    )
    # Useful for calibrating MAHALANOBIS_THRESHOLD. Note this is emitted at INFO,
    # which the module's WARNING-level basicConfig suppresses; lower it to see it.
    logger.info("Minimum Mahalanobis distance: %.4f", min_distance)
    return min_distance < MAHALANOBIS_THRESHOLD
# Metadata reported for each output index of the classifier head.
CLASS_INFO = {
    0: {"name": "Fall Army Worm", "problem_id": "126", "crop": "maize"},
    1: {"name": "Phosphorus Deficiency", "problem_id": "142", "crop": "maize"},
    2: {"name": "Bacterial Leaf Blight", "problem_id": "203", "crop": "rice"},
}

# (connect, read) timeouts in seconds for fetching the user-supplied URL, so a
# slow or unresponsive host cannot hang the worker indefinitely.
IMAGE_DOWNLOAD_TIMEOUT = (5, 30)


def _invalid_result(predicted_class=None, problem_id=None, confidence=None):
    """Build the response dict returned for every rejected prediction."""
    return {
        "status": "invalid",
        "predicted_class": predicted_class,
        "problem_id": problem_id,
        "confidence": confidence,
    }


def _penultimate_features(model, image_tensor):
    """Run a torchvision ResNet up to, but not including, ``model.fc``.

    Mirrors torchvision's ResNet forward pass step by step instead of swapping
    ``model.fc`` for ``nn.Identity``. The shared model is therefore never
    mutated, so an exception or a concurrent request cannot leave it without
    its classification head.
    """
    x = model.conv1(image_tensor)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)
    x = model.avgpool(x)
    return torch.flatten(x, 1)


# Prediction function for an uploaded image
def predict_from_image_url(image_url):
    """Classify the crop image found at ``image_url``.

    The image is downloaded and preprocessed with ``transform``. Penultimate
    features are extracted once and used for two purposes: the Mahalanobis
    OOD check (``is_in_distribution``) and, via ``main_model.fc``, the class
    logits.

    Args:
        image_url: HTTP(S) URL supplied by the user.

    Returns:
        A dict with keys ``status`` ("valid" or "invalid"), ``predicted_class``,
        ``problem_id`` and ``confidence`` (a percentage string such as
        "93.41%", or None). Status is "invalid" when the image is judged
        out-of-distribution or when confidence is below
        ``CONFIDENCE_THRESHOLD``. Any error is logged and reported as an
        "invalid" result rather than raised.
    """
    try:
        # NOTE(review): the URL is user-controlled and fetched server-side
        # (SSRF risk); consider a host allow-list if the Space can reach
        # internal services.
        response = requests.get(image_url, timeout=IMAGE_DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")  # force 3 channels
        image_tensor = transform(image).unsqueeze(0)  # Shape: [1, 3, 224, 224]

        with torch.no_grad():
            feature_batch = _penultimate_features(main_model, image_tensor)  # [1, 2048]
            # The head is Dropout + Linear; in eval mode this equals a full forward pass.
            outputs = main_model.fc(feature_batch)  # [1, num_classes]
        features = feature_batch[0].cpu().numpy()

        # Stage 1: OOD detection using Mahalanobis distance.
        if not is_in_distribution(features):
            logger.warning("Image URL %s detected as out-of-distribution.", image_url)
            return _invalid_result()

        # Stage 2: main model prediction.
        probabilities = torch.softmax(outputs, dim=1)[0]
        predicted_class = torch.argmax(outputs, dim=1).item()

        predicted_info = CLASS_INFO.get(predicted_class)
        if predicted_info is None:
            logger.warning(
                "Unexpected class prediction: %s for image URL: %s",
                predicted_class,
                image_url,
            )
            return _invalid_result()

        predicted_name = predicted_info["name"]
        problem_id = predicted_info["problem_id"]
        confidence = probabilities[predicted_class].item()
        confidence_text = f"{confidence * 100:.2f}%"

        if confidence < CONFIDENCE_THRESHOLD:
            logger.warning(
                "Low confidence prediction: %s with confidence %.2f%% for image URL: %s",
                predicted_name,
                confidence * 100,
                image_url,
            )
            return _invalid_result(predicted_name, problem_id, confidence_text)

        return {
            "status": "valid",
            "predicted_class": predicted_name,
            "problem_id": problem_id,
            "confidence": confidence_text,
        }
    except Exception as e:
        # Request boundary for the Gradio handler: log with traceback and
        # report a generic invalid result instead of crashing the UI.
        logger.exception("Error processing image URL %s: %s", image_url, e)
        return _invalid_result()
# Gradio UI: one text box for the image URL; the prediction dict is shown as JSON.
_APP_TITLE = "Crop Anomaly Classification"
_APP_DESCRIPTION = (
    "Enter a URL to an image for classification "
    "(Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight)."
)

demo = gr.Interface(
    fn=predict_from_image_url,
    inputs="text",
    outputs="json",
    title=_APP_TITLE,
    description=_APP_DESCRIPTION,
)


def _main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()