Baskar2005 commited on
Commit
fae8423
·
verified ·
1 Parent(s): 7903d08

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +117 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms, models
5
+ from PIL import Image
6
+ import torch.nn.functional as F
7
+ import gradio as gr
8
+
9
+ class TomatoLeafDiseaseDetectionApp:
10
+ def __init__(self):
11
+ self.class_names = [
12
+ 'Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight',
13
+ 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot',
14
+ 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold',
15
+ 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite',
16
+ 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus',
17
+ 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy'
18
+ ]
19
+
20
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ self.model = self.load_model()
22
+
23
+ def load_model(self):
24
+ """
25
+ Load the trained EfficientNet model with the weights for tomato leaf disease detection.
26
+ """
27
+ # Define the model structure
28
+ base_model = models.efficientnet_b0(weights=None) # No pretrained weights
29
+ base_model.classifier = nn.Identity() # Remove the original classifier
30
+ feature_size = 1280 # EfficientNetB0 output feature size
31
+
32
+ model = nn.Sequential(
33
+ base_model,
34
+ nn.Dropout(0.3),
35
+ nn.Linear(feature_size, len(self.class_names))
36
+ )
37
+
38
+ # Load the model weights
39
+ model_path = "tomato_leaf_disease_model.pth" # Update this path
40
+ model.load_state_dict(torch.load(model_path, map_location=self.device))
41
+ model.to(self.device)
42
+ model.eval() # Set the model to evaluation mode
43
+ return model
44
+
45
+ def predict_disease(self, image_path):
46
+ """
47
+ Predict the tomato leaf disease from the given image.
48
+
49
+ Args:
50
+ image_path (str): Path to the input image.
51
+
52
+ Returns:
53
+ tuple: Predicted disease name and confidence score.
54
+ """
55
+ try:
56
+ # Image preprocessing
57
+ transform = transforms.Compose([
58
+ transforms.Resize((224, 224)),
59
+ transforms.ToTensor(),
60
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Normalize for EfficientNet
61
+ ])
62
+ image = Image.open(image_path).convert("RGB")
63
+ input_tensor = transform(image).unsqueeze(0).to(self.device)
64
+
65
+ # Perform prediction
66
+ with torch.no_grad():
67
+ outputs = self.model(input_tensor)
68
+ probabilities = F.softmax(outputs, dim=1)
69
+ predicted_class = probabilities.argmax(1)
70
+ confidence_score = probabilities[0, predicted_class.item()].item()
71
+
72
+ predicted_class_name = self.class_names[predicted_class.item()]
73
+ return predicted_class_name, confidence_score
74
+ except Exception as e:
75
+ return f"Error: {str(e)}", 0.0
76
+
77
+ def gradio_interface(self):
78
+ """
79
+ Launch the Gradio interface for tomato leaf disease detection.
80
+ """
81
+ def classify_image(image_path):
82
+ disease_name, confidence = self.predict_disease(image_path)
83
+ return disease_name, f"Confidence: {confidence:.2f}"
84
+
85
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
86
+ gr.HTML("<center><h1>Tomato Leaf Disease Detection</h1></center>")
87
+
88
+ with gr.Row():
89
+ input_image = gr.Image(type="filepath", label="Upload Leaf Image", source="upload")
90
+ with gr.Column():
91
+ output_label = gr.Label(label="Predicted Disease")
92
+ confidence_text = gr.Textbox(label="Confidence Score")
93
+
94
+ with gr.Row():
95
+ button = gr.Button(value="Detect Disease")
96
+
97
+ button.click(
98
+ classify_image,
99
+ inputs=[input_image],
100
+ outputs=[output_label, confidence_text]
101
+ )
102
+
103
+ gr.Examples(
104
+ examples=[
105
+ "tomato_earlt_blight.jpg", # Replace with your example paths
106
+ "yellow_leaf_curl.jpg",
107
+ ],
108
+ inputs=[input_image],
109
+ outputs=[output_label, confidence_text],
110
+ label="Example Images"
111
+ )
112
+
113
+ demo.launch(debug=True)
114
+
115
+ if __name__ == "__main__":
116
+ app = TomatoLeafDiseaseDetectionApp()
117
+ app.gradio_interface()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio