"""Gradio app that classifies crop anomalies from an image URL.

A ResNet-50 checkpoint is downloaded from the Hugging Face Hub at import
time.  Each request is screened with an energy-based out-of-distribution
check and a softmax confidence threshold before a prediction is reported.
"""

import logging
from io import BytesIO

import gradio as gr
import requests
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

# Setup logging
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# Define the number of classes
num_classes = 3

# Confidence threshold for main model predictions
CONFIDENCE_THRESHOLD = 0.8  # 80%

# Energy threshold for OOD detection (to be calibrated)
ENERGY_THRESHOLD = -5.0  # Placeholder, will calibrate

# Upper bound (seconds) on the image download so that an unresponsive host
# cannot block a request forever.
REQUEST_TIMEOUT_SECONDS = 10

# Class index -> metadata reported to the caller.
CLASS_INFO = {
    0: {"name": "Fall Army Worm", "problem_id": "126", "crop": "maize"},
    1: {"name": "Phosphorus Deficiency", "problem_id": "142", "crop": "maize"},
    2: {"name": "Bacterial Leaf Blight", "problem_id": "203", "crop": "rice"},
}


def download_model():
    """Download the checkpoint from the Hugging Face Hub and return its local path."""
    model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
    return model_path


def load_main_model(model_path):
    """Build the ResNet-50 classifier and load its weights from ``model_path``.

    The final layer is replaced by ``Dropout -> Linear(num_classes)``.  The
    checkpoint must be a dict holding the weights under ``'model_state_dict'``.
    Keys named ``fc.weight`` / ``fc.bias`` are renamed to ``fc.1.weight`` /
    ``fc.1.bias`` so they land on the ``Linear`` layer inside the
    ``Sequential`` head.

    Returns:
        The model in eval mode, on CPU.
    """
    model = models.resnet50(pretrained=False)
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, num_classes),  # 3 classes
    )

    # NOTE(security): torch.load unpickles the file, so only checkpoints from
    # a trusted repository should be loaded here.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

    # Adjust for state dict mismatch by renaming the classifier keys.
    state_dict = checkpoint['model_state_dict']
    new_state_dict = {}
    for key, value in state_dict.items():
        if key in ("fc.weight", "fc.bias"):
            new_state_dict[f"fc.1.{key.split('.')[-1]}"] = value
        else:
            new_state_dict[key] = value

    # strict=False tolerates key mismatches; surface them so a checkpoint
    # that does not match the architecture is not used silently.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        logger.warning(
            "State dict mismatch while loading model: missing=%s unexpected=%s",
            load_result.missing_keys,
            load_result.unexpected_keys,
        )

    model.eval()
    return model


# Path to your model
model_path = download_model()
main_model = load_main_model(model_path)

# Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def compute_energy_score(logits, temperature=1.0):
    """Return the energy score ``-T * logsumexp(logits / T)`` as a float.

    ``logits`` must contain exactly one row: the reduction runs over
    ``dim=1`` and the result is converted with ``.item()``.
    """
    return -temperature * torch.logsumexp(logits / temperature, dim=1).item()


def is_in_distribution(logits):
    """Return True when the energy score of ``logits`` is below ``ENERGY_THRESHOLD``."""
    energy = compute_energy_score(logits)
    # Logged at INFO for calibration; hidden while the logging level
    # configured above stays at WARNING.
    logger.info(f"Energy score: {energy:.4f}")
    return energy < ENERGY_THRESHOLD  # Lower (more negative) energy means ID


def _invalid_result(predicted_class=None, problem_id=None, confidence=None):
    """Build the response dict used for every rejected request."""
    return {
        "status": "invalid",
        "predicted_class": predicted_class,
        "problem_id": problem_id,
        "confidence": confidence,
    }


def predict_from_image_url(image_url):
    """Download the image at ``image_url`` and classify it.

    Returns:
        A dict with keys ``status``, ``predicted_class``, ``problem_id`` and
        ``confidence``.  ``status`` is ``"valid"`` only when the image passes
        the energy check and the top-class probability reaches
        ``CONFIDENCE_THRESHOLD``.  A low-confidence result is ``"invalid"``
        but still carries the predicted class, problem id and confidence.
        Every other failure, including download or decoding errors, yields
        ``"invalid"`` with the remaining fields set to ``None``.
    """
    try:
        # Download the image from the provided URL; the timeout keeps an
        # unresponsive host from hanging the request indefinitely.
        response = requests.get(image_url, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")  # 3 channels

        # Apply transformations
        image_tensor = transform(image).unsqueeze(0)  # Shape: [1, 3, 224, 224]

        # Stage 1: OOD detection using the energy score
        with torch.no_grad():
            logits = main_model(image_tensor)  # Shape: [1, 3]

        if not is_in_distribution(logits):
            logger.warning(f"Image URL {image_url} detected as out-of-distribution.")
            return _invalid_result()

        # Stage 2: main model prediction.  The logits were produced under
        # no_grad, so no autograd graph is attached to them.
        probabilities = torch.softmax(logits, dim=1)[0]
        predicted_class = torch.argmax(logits, dim=1).item()

        # Validate predicted class index
        if predicted_class not in CLASS_INFO:
            logger.warning(
                f"Unexpected class prediction: {predicted_class} for image URL: {image_url}"
            )
            return _invalid_result()

        # Get predicted class info
        predicted_info = CLASS_INFO[predicted_class]
        predicted_name = predicted_info["name"]
        problem_id = predicted_info["problem_id"]
        confidence = probabilities[predicted_class].item()

        # Check confidence threshold
        if confidence < CONFIDENCE_THRESHOLD:
            logger.warning(
                f"Low confidence prediction: {predicted_name} with confidence {confidence*100:.2f}% "
                f"for image URL: {image_url}"
            )
            return _invalid_result(
                predicted_class=predicted_name,
                problem_id=problem_id,
                confidence=f"{confidence*100:.2f}%",
            )

        # Return successful prediction
        return {
            "status": "valid",
            "predicted_class": predicted_name,
            "problem_id": problem_id,
            "confidence": f"{confidence*100:.2f}%",
        }

    except Exception as e:
        # UI boundary: any failure is reported as an invalid result rather
        # than propagated to Gradio.
        logger.error(f"Error processing image URL {image_url}: {str(e)}")
        return _invalid_result()


# Gradio interface
demo = gr.Interface(
    fn=predict_from_image_url,
    inputs="text",
    outputs="json",
    title="Crop Anomaly Classification",
    description="Enter a URL to an image for classification (Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight).",
)

if __name__ == "__main__":
    demo.launch()