"""Gradio app that classifies crop anomalies from an image URL.

Pipeline for each request:

1. Download the image and apply the ImageNet-style preprocessing.
2. Extract penultimate-layer ResNet-50 features.
3. Reject the image as out-of-distribution (OOD) when its minimum
   Mahalanobis distance to the per-class statistics is too large.
4. Otherwise classify it and apply a confidence threshold.
"""

import logging
import os
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from scipy.spatial.distance import mahalanobis
from torchvision import models, transforms

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define the number of classes
num_classes = 3

# Confidence threshold for main model predictions
CONFIDENCE_THRESHOLD = 0.8  # 80%

# Mahalanobis distance threshold for OOD detection
MAHALANOBIS_THRESHOLD = 400.0  # Calibrate this using a validation set

# Upper bound (seconds) on the image download so that an unresponsive host
# cannot block a worker forever.
REQUEST_TIMEOUT_SECONDS = 15

# Diagonal loading added to each covariance matrix before inversion, for
# numerical stability.
COVARIANCE_EPSILON = 1e-6

# Metadata for each class index produced by the classifier head.
CLASS_INFO = {
    0: {"name": "Fall Army Worm", "problem_id": "126", "crop": "maize"},
    1: {"name": "Phosphorus Deficiency", "problem_id": "142", "crop": "maize"},
    2: {"name": "Bacterial Leaf Blight", "problem_id": "203", "crop": "rice"},
}


def download_model():
    """Download the model checkpoint from Hugging Face and return its local path."""
    model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
    return model_path


def load_main_model(model_path):
    """Build the ResNet-50 classifier and load its weights from *model_path*.

    The checkpoint is expected to hold the weights under the
    ``"model_state_dict"`` key. Checkpoint entries named ``fc.weight`` /
    ``fc.bias`` are renamed to ``fc.1.*`` because the head built here is
    ``Sequential(Dropout, Linear)``, which places the linear layer at index 1.

    Returns:
        The model in evaluation mode, on CPU.

    Raises:
        KeyError: if the checkpoint has no ``"model_state_dict"`` entry.
    """
    # Calling resnet50() with no arguments yields a randomly initialised
    # network under both the legacy ``pretrained=`` and the current
    # ``weights=`` torchvision APIs, avoiding the deprecation warning.
    model = models.resnet50()
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, num_classes),  # 3 classes
    )

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a repository you trust.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

    # Adjust for state dict mismatch by renaming keys
    state_dict = checkpoint["model_state_dict"]
    new_state_dict = {}
    for key, value in state_dict.items():
        if key in ("fc.weight", "fc.bias"):
            new_state_dict[f"fc.1.{key.split('.')[-1]}"] = value
        else:
            new_state_dict[key] = value

    # strict=False tolerates mismatches, so surface them in the log instead of
    # silently serving a partially initialised model.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys:
        logger.warning("Checkpoint is missing keys: %s", load_result.missing_keys)
    if load_result.unexpected_keys:
        logger.warning(
            "Checkpoint has unexpected keys: %s", load_result.unexpected_keys
        )

    model.eval()
    return model


# Load class statistics for Mahalanobis distance.
# The code below reads it as a mapping: label -> {"mean": tensor, "cov": tensor}.
try:
    class_statistics = torch.load(
        "class_statistics.pth", map_location=torch.device("cpu")
    )
except FileNotFoundError:
    logger.error(
        "class_statistics.pth not found. Please ensure the file is in the same directory as app.py."
    )
    raise

# Path to your model
model_path = download_model()
main_model = load_main_model(model_path)

# Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Lazily filled cache: label -> (mean as ndarray, inverse covariance as ndarray).
# Inverting a feature-sized covariance matrix is expensive and its inputs never
# change, so it is done at most once per class rather than on every request.
_class_gaussian_cache = {}


def _regularized_inverse(cov_np):
    """Return the inverse of *cov_np* after adding a small diagonal epsilon."""
    return np.linalg.inv(cov_np + np.eye(cov_np.shape[0]) * COVARIANCE_EPSILON)


def _get_class_gaussian(label):
    """Return ``(mean, inverse_covariance)`` NumPy arrays for *label*, cached."""
    cached = _class_gaussian_cache.get(label)
    if cached is None:
        stats = class_statistics[label]
        mean_np = stats["mean"].cpu().numpy()
        cov_np = stats["cov"].cpu().numpy()
        cached = (mean_np, _regularized_inverse(cov_np))
        _class_gaussian_cache[label] = cached
    return cached


def compute_mahalanobis_distance(features, mean, cov):
    """Compute the Mahalanobis distance between *features* and a Gaussian.

    Args:
        features: 1-D NumPy feature vector.
        mean: PyTorch tensor holding the class mean.
        cov: PyTorch tensor holding the class covariance matrix.

    Returns:
        The distance as computed by ``scipy.spatial.distance.mahalanobis``.
    """
    # Convert PyTorch tensors to NumPy arrays for scipy
    mean_np = mean.cpu().numpy()
    cov_np = cov.cpu().numpy()
    return mahalanobis(features, mean_np, _regularized_inverse(cov_np))


def is_in_distribution(features):
    """Return True when *features* lies close enough to any known class.

    The minimum Mahalanobis distance over all classes in
    ``class_statistics`` is compared against ``MAHALANOBIS_THRESHOLD``.
    With no class statistics available the image is treated as OOD.
    """
    distances = []
    for label in class_statistics:
        mean_np, cov_inv = _get_class_gaussian(label)
        distances.append(mahalanobis(features, mean_np, cov_inv))

    if not distances:
        logger.error("No class statistics available for OOD detection.")
        return False

    min_distance = min(distances)
    logger.info("Minimum Mahalanobis distance: %.4f", min_distance)
    return bool(min_distance < MAHALANOBIS_THRESHOLD)


def _extract_features(model, image_tensor):
    """Run the ResNet backbone and return the flattened pooled features.

    This mirrors the torchvision ResNet forward pass up to (but excluding) the
    ``fc`` head. Unlike temporarily swapping ``model.fc`` for ``nn.Identity``,
    it never mutates the shared model, so a failure mid-request cannot leave
    the model without its head and concurrent requests cannot interfere.
    """
    x = model.conv1(image_tensor)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)
    x = model.avgpool(x)
    return torch.flatten(x, 1)


def _invalid_result():
    """Return the response payload used for rejected or failed requests."""
    return {
        "status": "invalid",
        "predicted_class": None,
        "problem_id": None,
        "confidence": None,
    }


def predict_from_image_url(image_url):
    """Classify the image found at *image_url*.

    Args:
        image_url: URL of the image to download and classify.

    Returns:
        A dict with keys ``status`` (``"valid"`` or ``"invalid"``),
        ``predicted_class``, ``problem_id`` and ``confidence``. Download or
        processing errors, OOD images and unknown class indices yield an
        ``"invalid"`` result with the other fields set to ``None``;
        low-confidence predictions yield ``"invalid"`` with the fields filled.
        This function does not raise: all exceptions are logged and mapped to
        the invalid payload.
    """
    try:
        if not isinstance(image_url, str) or not image_url.strip():
            logger.warning("No image URL provided.")
            return _invalid_result()
        image_url = image_url.strip()

        # Download the image from the provided URL.
        # NOTE(security): the URL is user-controlled; the server fetches it
        # as-is (potential SSRF) and reads the whole body into memory.
        # Restrict reachable hosts / response size if this is exposed publicly.
        response = requests.get(image_url, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")  # 3 channels

        # Apply transformations
        image_tensor = transform(image).unsqueeze(0)  # Shape: [1, 3, 224, 224]
        logger.info(f"Input image tensor shape: {image_tensor.shape}")

        # Extract features from the penultimate layer
        with torch.no_grad():
            feature_tensor = _extract_features(main_model, image_tensor)  # [1, 2048]
        features = feature_tensor[0].cpu().numpy()  # Convert to numpy for scipy

        # Stage 1: OOD Detection using Mahalanobis distance
        if not is_in_distribution(features):
            logger.warning(f"Image URL {image_url} detected as out-of-distribution.")
            return _invalid_result()

        # Stage 2: Main Model Prediction. The head is applied to the features
        # already computed, which equals a full forward pass without running
        # the backbone a second time.
        with torch.no_grad():
            outputs = main_model.fc(feature_tensor)  # Shape: [1, 3]
            logger.info(f"Model output shape: {outputs.shape}")
            logger.info(f"Raw logits: {outputs[0].numpy()}")
            probabilities = torch.softmax(outputs, dim=1)[0]  # Convert to probabilities
            logger.info(f"Softmax probabilities: {probabilities.numpy()}")
            predicted_class = torch.argmax(outputs, dim=1).item()

        # Validate predicted class index
        if predicted_class not in CLASS_INFO:
            logger.warning(
                f"Unexpected class prediction: {predicted_class} for image URL: {image_url}"
            )
            return _invalid_result()

        # Get predicted class info
        predicted_info = CLASS_INFO[predicted_class]
        predicted_name = predicted_info["name"]
        problem_id = predicted_info["problem_id"]
        confidence = probabilities[predicted_class].item()
        result = {
            "status": "valid",
            "predicted_class": predicted_name,
            "problem_id": problem_id,
            "confidence": f"{confidence*100:.2f}%",
        }

        # Check confidence threshold
        if confidence < CONFIDENCE_THRESHOLD:
            logger.warning(
                f"Low confidence prediction: {predicted_name} with confidence {confidence*100:.2f}% "
                f"for image URL: {image_url}"
            )
            result["status"] = "invalid"

        return result

    except Exception as e:
        # Top-level request boundary: never propagate to the UI, but keep the
        # traceback in the log for diagnosis.
        logger.exception(f"Error processing image URL {image_url}: {str(e)}")
        return _invalid_result()


# Gradio interface
demo = gr.Interface(
    fn=predict_from_image_url,
    inputs="text",
    outputs="json",
    title="Crop Anomaly Classification",
    description="Enter a URL to an image for classification (Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight).",
)

if __name__ == "__main__":
    demo.launch()