"""Gradio app that classifies a crop anomaly from an image URL.

A ResNet-50 checkpoint is downloaded from the Hugging Face Hub at import
time. Each request downloads the image at the given URL, runs it through
the model and reports the prediction as "valid" only when the softmax
distribution is both peaked (low normalized entropy) and confident.
"""

import logging
import os
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import models, transforms

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define the number of classes
num_classes = 3

# Confidence threshold for reliable predictions
CONFIDENCE_THRESHOLD = 0.8  # 80%

# Entropy threshold for flat probability distribution (to detect
# non-maize/rice images). Lower entropy means a more peaked distribution
# (more confident).
ENTROPY_THRESHOLD = 0.9

# Seconds to wait for the remote image server; without a timeout a stalled
# server would block the request handler forever.
REQUEST_TIMEOUT = 10

# Metadata for every class index the model can emit.
CLASS_INFO = {
    0: {"name": "Fall Army Worm", "problem_id": "126", "crop": "maize"},
    1: {"name": "Phosphorus Deficiency", "problem_id": "142", "crop": "maize"},
    2: {"name": "Bacterial Leaf Blight", "problem_id": "203", "crop": "rice"},
}

# The checkpoint stores the classifier head as "fc.*", while the model built
# in load_model wraps the head in nn.Sequential, so the Linear layer lives at
# "fc.1.*".
_FC_KEY_RENAMES = {"fc.weight": "fc.1.weight", "fc.bias": "fc.1.bias"}


def download_model():
    """Download the checkpoint from the Hugging Face Hub.

    Returns:
        The value returned by ``hf_hub_download`` (the local file path of
        ``model.pth`` from the ``jays009/Resnet3`` repository).
    """
    return hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")


def load_model(model_path):
    """Build the ResNet-50 classifier and load weights from ``model_path``.

    Args:
        model_path: Path to a checkpoint containing a ``model_state_dict``
            entry.

    Returns:
        The model in eval mode, with tensors mapped to the CPU.

    Raises:
        KeyError: If the checkpoint has no ``model_state_dict`` entry.
    """
    # NOTE(review): ``pretrained`` is deprecated in newer torchvision in
    # favour of ``weights=None``; kept as-is for older installations.
    model = models.resnet50(pretrained=False)
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, num_classes),  # 3 classes
    )

    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. The checkpoint comes from a remote repository; consider
    # ``weights_only=True`` once the minimum torch version supports it.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

    # Adjust for state dict mismatch by renaming the classifier-head keys.
    state_dict = checkpoint["model_state_dict"]
    new_state_dict = {
        _FC_KEY_RENAMES.get(key, key): value
        for key, value in state_dict.items()
    }

    # strict=False tolerates key mismatches, so surface them in the log:
    # otherwise a wrong checkpoint silently leaves layers randomly
    # initialised.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys:
        logger.warning(
            "Checkpoint is missing keys: %s", load_result.missing_keys
        )
    if load_result.unexpected_keys:
        logger.warning(
            "Checkpoint has unexpected keys: %s", load_result.unexpected_keys
        )

    model.eval()
    return model


# Path to your model
model_path = download_model()
model = load_model(model_path)

# Define the transformation for the input image.
# NOTE(review): the normalization constants match the usual ImageNet channel
# statistics; presumably the checkpoint was trained with them -- confirm.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def compute_entropy(probabilities):
    """Return the entropy of ``probabilities`` normalized by log(num_classes).

    Args:
        probabilities: Probability vector (a tensor, or anything
            ``torch.as_tensor`` accepts).

    Returns:
        Shannon entropy (natural log) divided by the maximum possible
        entropy for ``num_classes`` outcomes, as a Python float. Values near
        1 indicate a flat distribution, values near 0 a peaked one.
    """
    # detach()/cpu() make the conversion safe for tensors that require grad
    # or live on another device.
    probs = torch.as_tensor(probabilities).detach().cpu().numpy()
    # Avoid log(0) by clipping to a small epsilon.
    probs = np.clip(probs, 1e-10, 1.0)
    entropy = -np.sum(probs * np.log(probs))
    # Normalize by the maximum possible entropy (log(num_classes)).
    max_entropy = np.log(num_classes)
    return float(entropy / max_entropy)


def _fetch_image(image_url):
    """Download ``image_url`` and return it as an RGB PIL image."""
    # NOTE(review): the URL is user-supplied and fetched server-side (SSRF
    # risk); consider restricting allowed hosts if this is exposed publicly.
    response = requests.get(image_url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()
    # Convert to RGB (3 channels) so grayscale/RGBA inputs fit the model.
    return Image.open(BytesIO(response.content)).convert("RGB")


def _result(status, predicted_class=None, problem_id=None, confidence=None):
    """Build the response dict; ``confidence`` is a 0-1 float or None."""
    return {
        "status": status,
        "predicted_class": predicted_class,
        "problem_id": problem_id,
        "confidence": (
            None if confidence is None else f"{confidence*100:.2f}%"
        ),
    }


def predict_from_image_url(image_url):
    """Classify the image found at ``image_url``.

    Args:
        image_url: URL of the image to download and classify.

    Returns:
        A dict with keys ``status`` ("valid" or "invalid"),
        ``predicted_class``, ``problem_id`` and ``confidence`` (a percentage
        string). The status is "invalid" when the normalized entropy exceeds
        ENTROPY_THRESHOLD, when the confidence is below CONFIDENCE_THRESHOLD,
        or when any error occurs; in the error case the other three fields
        are None.
    """
    try:
        image = _fetch_image(image_url)

        # Apply transformations
        image_tensor = transform(image).unsqueeze(0)  # Shape: [1, 3, 224, 224]
        logger.info(f"Input image tensor shape: {image_tensor.shape}")

        # Perform prediction
        with torch.no_grad():
            outputs = model(image_tensor)  # Shape: [1, 3]
            logger.info(f"Model output shape: {outputs.shape}")
            probabilities = torch.softmax(outputs, dim=1)[0]
            predicted_class = torch.argmax(outputs, dim=1).item()

        # Validate predicted class index
        if predicted_class not in CLASS_INFO:
            logger.warning(
                f"Unexpected class prediction: {predicted_class} "
                f"for image URL: {image_url}"
            )
            return _result("invalid")

        predicted_info = CLASS_INFO[predicted_class]
        predicted_name = predicted_info["name"]
        problem_id = predicted_info["problem_id"]
        # Confidence score for the predicted class
        confidence = probabilities[predicted_class].item()

        entropy = compute_entropy(probabilities)
        logger.info(
            f"Prediction entropy: {entropy:.4f}, confidence: {confidence:.4f} "
            f"for image URL: {image_url}"
        )

        # High entropy (flat distribution) suggests the image may not be
        # maize or rice.
        if entropy > ENTROPY_THRESHOLD:
            logger.warning(
                f"High entropy ({entropy:.4f} > {ENTROPY_THRESHOLD}) "
                f"for image URL: {image_url}. "
                "Image may not be of maize or rice."
            )
            return _result("invalid", predicted_name, problem_id, confidence)

        # Check confidence threshold
        if confidence < CONFIDENCE_THRESHOLD:
            logger.warning(
                f"Low confidence prediction: {predicted_name} with confidence "
                f"{confidence*100:.2f}% for image URL: {image_url}"
            )
            return _result("invalid", predicted_name, problem_id, confidence)

        # Return successful prediction
        return _result("valid", predicted_name, problem_id, confidence)

    except Exception as e:
        # Top-level UI boundary: report every failure as an "invalid" result
        # instead of crashing the interface, but keep the traceback in the
        # log.
        logger.exception("Error processing image URL %s: %s", image_url, e)
        return _result("invalid")


# Gradio interface
demo = gr.Interface(
    fn=predict_from_image_url,
    inputs="text",
    outputs="json",
    title="Crop Anomaly Classification",
    description="Enter a URL to an image for classification (Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight).",
)

if __name__ == "__main__":
    demo.launch()