# Page-header residue captured with this listing (not code) -- Spaces: Sleeping
import logging
import os
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import models, transforms
# Setup logging: configure the root logger at import time and create the
# module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Number of output classes of the classifier head.
num_classes = 3

# Confidence threshold for reliable predictions: the top-class probability
# must reach this value.
CONFIDENCE_THRESHOLD = 0.8  # 80%

# Entropy threshold for flat probability distributions (to detect
# non-maize/rice images). Lower entropy means a more peaked, i.e. more
# confident, distribution.
ENTROPY_THRESHOLD = 0.9
# Download model from Hugging Face
def download_model(repo_id="jays009/Resnet3", filename="model.pth"):
    """Download the model checkpoint from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that hosts the checkpoint. Defaults to the
            repository this script was written against.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        Whatever ``hf_hub_download`` returns for the requested file; the
        caller passes it on to ``torch.load`` as the checkpoint path.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename)
# Load the model from Hugging Face
def load_model(model_path):
    """Build the ResNet-50 classifier and load its weights from a checkpoint.

    The final fully connected layer is replaced by ``Dropout -> Linear`` with
    ``num_classes`` outputs. Checkpoint entries named ``fc.weight`` /
    ``fc.bias`` are renamed to ``fc.1.weight`` / ``fc.1.bias`` so that they
    land on the ``Linear`` layer inside that ``Sequential`` head.

    Args:
        model_path: Path to a checkpoint file containing a
            ``'model_state_dict'`` entry.

    Returns:
        The model, switched to evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # Calling resnet50() without arguments builds the architecture without
    # pretrained weights, which is what the deprecated ``pretrained=False``
    # asked for.
    model = models.resnet50()
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, num_classes),  # 3 classes
    )

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; consider
    # passing weights_only=True if the installed PyTorch supports it.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

    # Adjust for the state-dict mismatch by renaming the head's keys.
    state_dict = checkpoint["model_state_dict"]
    renamed_keys = {"fc.weight": "fc.1.weight", "fc.bias": "fc.1.bias"}
    new_state_dict = {
        renamed_keys.get(key, key): value for key, value in state_dict.items()
    }

    # strict=False tolerates mismatched keys; report them so that a head left
    # with randomly initialised weights does not go unnoticed.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        logger.warning(
            "State dict mismatch while loading %s - missing keys: %s, "
            "unexpected keys: %s",
            model_path,
            load_result.missing_keys,
            load_result.unexpected_keys,
        )

    model.eval()
    return model
# Download the checkpoint and build the model once, at import time.
model_path = download_model()
model = load_model(model_path)

# Per-channel normalization constants (these are the commonly used ImageNet
# mean / standard deviation values).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Transformation for the input image: resize, center-crop to 224x224,
# convert to a tensor and normalize each channel.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
])
# Function to compute entropy of the probability distribution
def compute_entropy(probabilities):
    """Return the Shannon entropy of a distribution, normalized to [0, 1].

    The entropy is divided by the maximum possible entropy for a distribution
    of the same length (``log(K)`` for ``K`` entries), so 0 means all mass on
    one class and 1 means a uniform distribution.

    Args:
        probabilities: 1-D probability vector, either a ``torch.Tensor`` or
            anything ``numpy.asarray`` accepts.

    Returns:
        The normalized entropy as a Python float. A distribution with at most
        one entry has no uncertainty and yields ``0.0``.
    """
    if isinstance(probabilities, torch.Tensor):
        # detach()/cpu() make the conversion safe for tensors that track
        # gradients or live on another device.
        probs = probabilities.detach().cpu().numpy()
    else:
        probs = np.asarray(probabilities)

    # log(1) == 0, so normalizing a single-entry distribution would divide
    # by zero.
    if probs.size <= 1:
        return 0.0

    # Avoid log(0) by clipping to a small epsilon.
    probs = np.clip(probs.astype(np.float64), 1e-10, 1.0)
    entropy = -np.sum(probs * np.log(probs))

    # Normalize by the maximum possible entropy for this many classes.
    max_entropy = np.log(probs.size)
    return float(entropy / max_entropy)
# Prediction function for an image fetched from a URL

# Class index -> reporting metadata for the model's output classes.
_CLASS_INFO = {
    0: {"name": "Fall Army Worm", "problem_id": "126", "crop": "maize"},
    1: {"name": "Phosphorus Deficiency", "problem_id": "142", "crop": "maize"},
    2: {"name": "Bacterial Leaf Blight", "problem_id": "203", "crop": "rice"},
}

# Seconds to wait on the image host. requests applies no timeout unless one
# is given, so without this a stalled server would block the handler forever.
_REQUEST_TIMEOUT_SECONDS = 10


def _build_result(status, predicted_class=None, problem_id=None, confidence=None):
    """Assemble the response dictionary returned by the prediction function."""
    return {
        "status": status,
        "predicted_class": predicted_class,
        "problem_id": problem_id,
        "confidence": confidence,
    }


def predict_from_image_url(image_url):
    """Download an image from ``image_url`` and classify it.

    Args:
        image_url: URL of the image to classify.

    Returns:
        A dict with the keys ``status``, ``predicted_class``, ``problem_id``
        and ``confidence``. ``status`` is ``"valid"`` only when the normalized
        entropy does not exceed ``ENTROPY_THRESHOLD`` and the top-class
        probability is at least ``CONFIDENCE_THRESHOLD``. Predictions rejected
        by those checks are ``"invalid"`` but keep the class fields filled in.
        For an empty or non-string URL, an unknown class index, or any error
        during download or inference, the other three fields are ``None``.
    """
    # Nothing to fetch: answer immediately instead of raising inside requests.
    if not isinstance(image_url, str) or not image_url.strip():
        return _build_result("invalid")

    try:
        # NOTE(security): the URL is caller-supplied and fetched from this
        # process; restrict reachable hosts if this is exposed to untrusted
        # users.
        response = requests.get(image_url, timeout=_REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        # Convert to RGB so the tensor always has 3 channels.
        image = Image.open(BytesIO(response.content)).convert("RGB")

        # Apply transformations and add the batch dimension.
        image_tensor = transform(image).unsqueeze(0)
        logger.info("Input image tensor shape: %s", image_tensor.shape)

        # Perform prediction.
        with torch.no_grad():
            outputs = model(image_tensor)
            logger.info("Model output shape: %s", outputs.shape)
            probabilities = torch.softmax(outputs, dim=1)[0]
            predicted_class = torch.argmax(outputs, dim=1).item()

        # Validate the predicted class index.
        predicted_info = _CLASS_INFO.get(predicted_class)
        if predicted_info is None:
            logger.warning(
                "Unexpected class prediction: %s for image URL: %s",
                predicted_class,
                image_url,
            )
            return _build_result("invalid")

        predicted_name = predicted_info["name"]
        problem_id = predicted_info["problem_id"]
        # Confidence score for the predicted class.
        confidence = probabilities[predicted_class].item()
        confidence_text = f"{confidence*100:.2f}%"

        entropy = compute_entropy(probabilities)
        logger.info(
            "Prediction entropy: %.4f, confidence: %.4f for image URL: %s",
            entropy,
            confidence,
            image_url,
        )

        # High entropy (flat distribution) suggests the image may not be
        # maize or rice.
        if entropy > ENTROPY_THRESHOLD:
            logger.warning(
                "High entropy (%.4f > %s) for image URL: %s. "
                "Image may not be of maize or rice.",
                entropy,
                ENTROPY_THRESHOLD,
                image_url,
            )
            return _build_result("invalid", predicted_name, problem_id, confidence_text)

        # Check the confidence threshold.
        if confidence < CONFIDENCE_THRESHOLD:
            logger.warning(
                "Low confidence prediction: %s with confidence %s "
                "for image URL: %s",
                predicted_name,
                confidence_text,
                image_url,
            )
            return _build_result("invalid", predicted_name, problem_id, confidence_text)

        # Successful prediction.
        return _build_result("valid", predicted_name, problem_id, confidence_text)
    except Exception as e:
        # Boundary handler for the UI callback: log with traceback and report
        # the request as invalid instead of propagating the error.
        logger.exception("Error processing image URL %s: %s", image_url, e)
        return _build_result("invalid")
# Gradio interface: a text box for the image URL in, the prediction dict
# rendered as JSON out.
demo = gr.Interface(
    fn=predict_from_image_url,
    inputs="text",
    outputs="json",
    title="Crop Anomaly Classification",
    description="Enter a URL to an image for classification (Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight).",
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()