File size: 4,303 Bytes
4f7c2c3
 
 
 
 
 
39dba98
4f7c2c3
 
 
 
 
 
 
 
 
 
 
 
39dba98
 
4f7c2c3
 
5a8efa9
811ddcb
 
 
 
 
 
39dba98
811ddcb
 
 
 
 
 
 
 
 
 
 
5a8efa9
811ddcb
 
 
aff9d06
e97dbab
4f7c2c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aef42cb
4f7c2c3
 
aef42cb
 
4f7c2c3
 
 
aef42cb
 
 
4f7c2c3
 
 
39dba98
aef42cb
 
 
 
 
 
 
 
39dba98
aef42cb
 
 
 
 
 
 
 
39dba98
aef42cb
 
 
 
 
 
 
 
39dba98
 
4f7c2c3
 
 
 
aef42cb
4f7c2c3
 
 
 
200f853
4f7c2c3
 
 
 
39dba98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import os
import logging
import requests
from io import BytesIO

# Route log records at INFO level and above to stderr via the root logger.
# (Logger.setLevel accepts the level name as a string.)
logging.basicConfig(level="INFO")

# Size of the classifier's output layer: one logit per crop-anomaly class.
num_classes: int = 3

# Download model from Hugging Face
def download_model(repo_id="jays009/Resnet3", filename="model.pth"):
    """Download (or reuse a locally cached copy of) the classifier checkpoint.

    Args:
        repo_id: Hugging Face Hub repository that holds the checkpoint.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        Local filesystem path to the checkpoint, as returned by
        ``huggingface_hub.hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename)

# Load the model from Hugging Face
def load_model(model_path, num_classes=3):
    """Build a ResNet-50 classifier and load fine-tuned weights into it.

    The classification head is ``Sequential(Dropout(0.5), Linear(in, num_classes))``.
    Checkpoint keys ``fc.weight`` / ``fc.bias`` (saved from a plain ``Linear``
    head) are therefore renamed to ``fc.1.weight`` / ``fc.1.bias`` before loading.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` that stores
            the weights under the ``"model_state_dict"`` key.
        num_classes: Output size of the final linear layer (default 3, matching
            the three classes this app reports).

    Returns:
        The model with loaded weights, on CPU, in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
        RuntimeError: If the classifier-head weights are missing from the
            checkpoint (the head would otherwise stay randomly initialised), or
            if a tensor's shape does not match the model.
    """
    logger = logging.getLogger(__name__)

    model = models.resnet50(pretrained=False)
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(model.fc.in_features, num_classes),
    )

    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a repository you trust.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

    # Map the plain-Linear head keys onto index 1 of the Sequential head.
    head_key_map = {"fc.weight": "fc.1.weight", "fc.bias": "fc.1.bias"}
    state_dict = {
        head_key_map.get(key, key): value
        for key, value in checkpoint["model_state_dict"].items()
    }

    # strict=False tolerates harmless mismatches, but report them instead of
    # ignoring them silently.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.unexpected_keys:
        logger.warning(
            "Ignoring unexpected checkpoint keys: %s", incompatible.unexpected_keys
        )
    if incompatible.missing_keys:
        missing_head = [k for k in incompatible.missing_keys if k.startswith("fc.")]
        if missing_head:
            raise RuntimeError(
                f"Checkpoint {model_path!r} is missing classifier-head weights: "
                f"{missing_head}"
            )
        logger.warning(
            "Checkpoint is missing keys (left at initial values): %s",
            incompatible.missing_keys,
        )

    model.eval()
    return model

# Fetch the checkpoint and build the model once, at import time; the
# prediction function below reads the module-level `model`.
model_path = download_model()
model = load_model(model_path)

# Preprocessing pipeline applied to every input image: shorter side resized
# to 256, centre 224x224 crop, conversion to a float tensor, then per-channel
# normalisation with the ImageNet mean/std.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)

# Display name and result message for each model output index, in index order.
_CLASS_NAMES = ("Fall Army Worm", "Phosphorus Deficiency", "Bacterial Leaf Blight")
_CLASS_MESSAGES = (
    "The photo is of Fall Army Worm with problem ID 126.",
    "The photo shows symptoms of Phosphorus Deficiency with Problem ID 142.",
    "The photo shows symptoms of Bacterial Leaf Blight with Problem ID 203.",
)
# Seconds to wait for the image server; requests has no default timeout and
# would otherwise block the worker indefinitely on an unresponsive host.
_IMAGE_FETCH_TIMEOUT = 15


def _fetch_image(image_url):
    """Download *image_url* and decode it as a 3-channel RGB PIL image."""
    response = requests.get(image_url, timeout=_IMAGE_FETCH_TIMEOUT)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


# Prediction function for an uploaded image
def predict_from_image_url(image_url):
    """Classify the image at *image_url* with the module-level ``model``.

    Args:
        image_url: HTTP(S) URL of the image to classify. Surrounding whitespace
            is ignored.

    Returns:
        On success, ``{"result": <message>, "probabilities": {<class>: "<pct>%"}}``.
        The percentages are formatted with two decimals. On any failure (empty
        URL, network or decoding error, unexpected model output), returns
        ``{"error": <message>}``. Exceptions are never raised to the caller.
    """
    logger = logging.getLogger(__name__)

    url = str(image_url).strip() if image_url is not None else ""
    if not url:
        return {"error": "No image URL provided."}

    try:
        image = _fetch_image(url)

        # Add a batch dimension for the model.
        image_tensor = transform(image).unsqueeze(0)
        logger.debug("Input image tensor shape: %s", tuple(image_tensor.shape))

        with torch.no_grad():
            outputs = model(image_tensor)
            logger.debug("Model output shape: %s", tuple(outputs.shape))
            probabilities = torch.softmax(outputs, dim=1)[0]
            predicted_class = int(torch.argmax(outputs, dim=1).item())

        if not 0 <= predicted_class < len(_CLASS_MESSAGES):
            return {"error": "Unexpected class prediction."}

        percentages = (probabilities * 100).tolist()
        return {
            "result": _CLASS_MESSAGES[predicted_class],
            "probabilities": {
                name: f"{percentages[index]:.2f}%"
                for index, name in enumerate(_CLASS_NAMES)
            },
        }

    except Exception as e:
        # UI boundary: report the failure to the client, but keep the traceback
        # in the server logs instead of discarding it.
        logger.exception("Prediction failed for URL %r", url)
        return {"error": str(e)}

# Gradio UI: a text box for the image URL; the prediction dict is shown as JSON.
_APP_DESCRIPTION = (
    "Enter a URL to an image for classification "
    "(Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight)."
)

demo = gr.Interface(
    title="Crop Anomaly Classification",
    description=_APP_DESCRIPTION,
    fn=predict_from_image_url,
    inputs="text",
    outputs="json",
)


def _main():
    """Start the Gradio server when this file is executed as a script."""
    demo.launch()


if __name__ == "__main__":
    _main()