# Crop_anomaly_id / app.py
# Last change by jays009: "Update app.py" (commit 97f39b9, verified), 3.42 kB.
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import logging
import requests
from io import BytesIO
# Route INFO-and-above records through the root logger so model download and
# load progress shows up in the app logs.
logging.basicConfig(level="INFO")

# Size of the classifier head; predictions are checked against this value.
num_classes = 3
# Download model from Hugging Face
def download_model():
try:
model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
logging.info("Model downloaded successfully.")
return model_path
except Exception as e:
logging.error(f"Failed to download model: {e}")
raise
# Load the model from Hugging Face
def load_model(model_path):
model = models.resnet50(pretrained=False)
num_features = model.fc.in_features
model.fc = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(num_features, 3) # 3 classes
)
# Load the checkpoint
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
# Adjust for state dict mismatch by renaming keys
state_dict = checkpoint['model_state_dict']
new_state_dict = {}
for k, v in state_dict.items():
if k == "fc.weight" or k == "fc.bias":
new_state_dict[f"fc.1.{k.split('.')[-1]}"] = v
else:
new_state_dict[k] = v
model.load_state_dict(new_state_dict, strict=False)
model.eval()
return model
except Exception as e:
logging.error(f"Failed to load model: {e}")
raise
# Path to your model
model_path = download_model()
model = load_model(model_path)
# Define the transformation for the input image
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Prediction function for an uploaded image
def predict_from_image_url(image_url):
try:
# Download the image from the provided URL
response = requests.get(image_url)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert('RGB')
# Apply transformations
image_tensor = transform(image).unsqueeze(0)
# Perform prediction
with torch.no_grad():
outputs = model(image_tensor)
if outputs.shape[1] != num_classes:
raise ValueError(f"Unexpected number of output classes: {outputs.shape[1]} (expected {num_classes})")
predicted_class = torch.argmax(outputs, dim=1).item()
# Interpret the result
class_map = {
0: "The photo is of Fall Army Worm with problem ID 126.",
1: "The photo shows symptoms of Phosphorus Deficiency with Problem ID 142.",
2: "The photo shows symptoms of Bacterial Leaf Blight with Problem ID 203."
}
return {"result": class_map.get(predicted_class, "Unexpected class prediction.")}
except Exception as e:
logging.error(f"Error during prediction: {e}")
return {"error": str(e)}
# Initialize Gradio interface
demo = gr.Interface(
fn=predict_from_image_url,
inputs="text",
outputs="json",
title="Maize Disease Classification",
description="Enter a URL to an image for classification (Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight).",
)
if __name__ == "__main__":
demo.launch()