Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| import torch | |
| from torch import nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import os | |
# Number of outputs of the classifier head that load_model() attaches.
num_classes = 2

# Preprocessing applied to every input image before inference:
# resize so the shorter side is 256 px, take the central 224x224 crop,
# convert to a float tensor in [0, 1], then normalise each RGB channel
# with the mean/std values commonly used with torchvision's
# ImageNet-pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def load_image_from_path(image_path):
    """Load an image file and turn it into a model-ready batch tensor.

    The image is decoded with PIL, converted to 3-channel RGB, run through
    the module-level ``transform`` pipeline and given a leading batch
    dimension of size 1.

    Args:
        image_path: Path to an image file on the local filesystem.

    Returns:
        The transformed image tensor with a batch dimension prepended
        (``unsqueeze(0)``).

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing regular file.
        Any exception PIL raises while decoding is propagated unchanged.
    """
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file not found at {image_path}")
    # The context manager closes the underlying file handle. convert("RGB")
    # returns a new, fully loaded image, so closing the original is safe.
    # Forcing RGB matters because palette, grayscale and RGBA inputs would
    # otherwise produce a channel count that does not match the 3-channel
    # Normalize step.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    return transform(rgb_image).unsqueeze(0)
def load_model(checkpoint_path="model.pth", map_location="cpu"):
    """Build the ResNet-50 classifier and load its fine-tuned weights.

    A torchvision ResNet-50 is created and its final fully connected layer
    is replaced with a ``num_classes``-way head. The state dict from
    ``checkpoint_path`` is then loaded.

    ``load_state_dict`` is strict by default: it fails unless every
    parameter and buffer is present in the checkpoint. A successful load
    therefore overwrites all weights, so no ImageNet weights are downloaded
    first.

    Args:
        checkpoint_path: Path to a state dict saved with ``torch.save``.
            Defaults to ``"model.pth"`` in the working directory.
        map_location: Where checkpoint tensors are mapped while loading.
            Defaults to ``"cpu"`` so a checkpoint saved on a GPU still loads
            on CPU-only hosts. The model itself lives on the CPU either way.

    Returns:
        The model, switched to evaluation mode.
    """
    # No pretrained download: the strict load below replaces every weight.
    model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints. Consider weights_only=True once the torch version allows.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def predict(image_tensor):
    """Classify a preprocessed image batch with the module-level ``model``.

    Args:
        image_tensor: Batch passed straight to ``model``. It is presumably
            the ``(1, C, H, W)`` tensor produced by ``load_image_from_path``;
            TODO confirm for other callers.

    Returns:
        Index of the highest-scoring class along dim 1. ``.item()`` only
        works on a single-element tensor, so exactly one image per batch is
        expected.
    """
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(image_tensor)
    best = logits.argmax(dim=1)
    return best.item()
# Build the classifier once, at import time, so every request reuses it.
# predict() reads this module-level ``model`` global.
model = load_model()
def predict_from_file(file_path):
    """Gradio handler: classify the image stored at ``file_path`` on the server.

    Args:
        file_path: Server-side path of the image, as typed into the textbox
            or sent by an API client.

    Returns:
        ``{"result": label}`` on success. ``label`` is ``"Fall armyworm"``
        for class 0 and ``"Healthy maize"`` for any other class.
        Otherwise ``{"error": message}``, both for empty input and for any
        exception raised while loading or classifying the image.
    """
    # A cleared or blank textbox is not worth a filesystem lookup, and
    # str() guards against None.
    if file_path is None or not str(file_path).strip():
        return {"error": "No image path provided."}
    try:
        image_tensor = load_image_from_path(file_path)
        predicted_class = predict(image_tensor)
    except Exception as e:  # top-level UI boundary: report, don't crash
        # Keep the full traceback in the server logs; the client only
        # receives str(e).
        import logging

        logging.getLogger(__name__).exception(
            "Prediction failed for %r", file_path
        )
        return {"error": str(e)}
    label = "Fall armyworm" if predicted_class == 0 else "Healthy maize"
    return {"result": label}
# Gradio interface wrapping predict_from_file: a text path in, JSON out.
# NOTE(review): this endpoint opens arbitrary server-side paths supplied by
# remote clients. Served on 0.0.0.0 with share=True, anyone with the link can
# likely probe which files exist on the host, because the error text differs.
# Consider restricting paths to an allowed directory or accepting uploads.
iface = gr.Interface(
    fn=predict_from_file,
    inputs=gr.Textbox(label="Image Path (Local)"),
    outputs=gr.JSON(),
    # live=True re-ran the handler on every keystroke, i.e. on every
    # half-typed path. Explicit submission is enough; API calls are
    # unaffected.
    live=False,
    title="Maize Anomaly Detection",
    description="Send a local file path via POST request to trigger prediction.",
)

# Launch the Gradio app only when run as a script (e.g. `python app.py`),
# not when the module is imported.
if __name__ == "__main__":
    iface.launch(share=True, server_name="0.0.0.0", server_port=7860)