Sobit commited on
Commit
5610138
Β·
verified Β·
1 Parent(s): 5a9b64e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -159
app.py CHANGED
@@ -2,188 +2,90 @@ import streamlit as st
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
 
 
 
5
  from io import BytesIO
6
- #from ultralytics import YOLO
7
  import os
8
- import tempfile
9
- import base64
10
- import requests
11
- from datetime import datetime
12
- import google.generativeai as genai # Import Gemini API
13
- from tensorflow.keras.models import load_model
14
- from vit import create_vit_classifier
15
- # Load the model
16
- vit_classifier = load_model('models/vit_updated.weights.h5') # Replace with your model file path
17
- # Configuring Google Gemini API
18
- GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
19
- genai.configure(api_key=GEMINI_API_KEY)
20
 
21
- # Loading YOLO model for crop disease detection
22
- #yolo_model = YOLO("models/best.pt")
23
 
24
- # Initializing conversation history if not set
25
- if "conversation_history" not in st.session_state:
26
- st.session_state.conversation_history = {}
27
-
28
- # Function to preprocess images
29
- def preprocess_image(image, target_size=(256, 256)):
30
- """Resize image for AI models."""
31
- image = Image.fromarray(image)
32
- image = image.resize(target_size)
33
- return image
34
 
35
- # Generate response from Gemini AI with history
36
- def generate_gemini_response(disease_list, user_context="", conversation_history=None):
37
- """Generate a structured diagnosis using Gemini API, considering conversation history."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  try:
39
  model = genai.GenerativeModel("gemini-1.5-pro")
40
-
41
- # Start with detected diseases
42
  prompt = f"""
43
- You are an expert plant pathologist. The detected crop diseases is: {predicted_labels}.
44
 
45
- User's context or question: {user_context if user_context else "Provide a general analysis"}
 
 
 
 
 
 
46
  """
47
-
48
- # Add past conversation history for better continuity
49
- if conversation_history:
50
- history_text = "\n\nPrevious conversation:\n"
51
- for entry in conversation_history:
52
- history_text += f"- User: {entry['question']}\n- AI: {entry['response']}\n"
53
- prompt += history_text
54
-
55
- # Ask Gemini for a structured diagnosis
56
- prompt += """
57
- For each detected disease, provide a structured analysis following this format:
58
- 1. Disease Name: [Name]
59
- - Pathogen: [Causative organism]
60
- - Severity Level: [Based on visual symptoms]
61
- - Key Symptoms:
62
- * [Symptom 1]
63
- * [Symptom 2]
64
- - Economic Impact:
65
- * [Brief description of potential crop losses]
66
- - Treatment Options:
67
- * Immediate actions: [Short-term solutions]
68
- * Long-term management: [Preventive measures]
69
- - Environmental Conditions:
70
- * Favorable conditions for disease development
71
- * Risk factors
72
- 2. Recommendations:
73
- - Immediate Steps:
74
- * [Action items for immediate control]
75
- - Prevention Strategy:
76
- * [Long-term prevention measures]
77
- - Monitoring Protocol:
78
- * [What to watch for]"""
79
-
80
-
81
  response = model.generate_content(prompt)
82
  return response.text if response else "No response from Gemini."
83
  except Exception as e:
84
  return f"Error connecting to Gemini API: {str(e)}"
85
 
86
- # Performing inference using YOLO
87
- def inference(image):
88
- """Detect crop diseases in the given image with confidence filtering."""
89
- predictions = vit_classifier.predict(image)
90
- predicted_labels = np.argmax(predictions, axis=1)
91
-
92
- return predicted_labels
93
-
94
-
95
 
96
-
97
- # Initialize Streamlit UI
98
- st.title("AI-Powered Crop Disease Detection & Diagnosis System")
99
-
100
- # Sidebar settings
101
- with st.sidebar:
102
- st.header("Settings")
103
-
104
- # Fake model selection (Still uses Gemini)
105
- selected_model = st.selectbox("Choose Model", ["Gemini", "GPT-4", "Claude", "Llama 3", "Mistral"], help="This app always uses Gemini.")
106
-
107
- confidence_threshold = st.slider("Detection Confidence Threshold", 0.0, 1.0, 0.4)
108
-
109
-
110
-
111
- if st.button("Clear Conversation History"):
112
- st.session_state.conversation_history = {}
113
- st.success("Conversation history cleared!")
114
-
115
-
116
- # User context input with example prompts
117
- st.subheader("Provide Initial Context or Ask a Question")
118
-
119
- # Generalized example prompts for easier input
120
- example_prompts = {
121
- "Select an example...": "",
122
- "General Plant Health Issue": "My plant leaves are wilting and turning yellow. Is this a disease or a nutrient deficiency?",
123
- "Leaf Spots and Discoloration": "I see dark spots on my crop leaves. Could this be a fungal or bacterial infection?",
124
- "Leaves Drying or Curling": "The leaves on my plants are curling and drying up. What could be causing this?",
125
- "Pest or Disease?": "I noticed tiny insects on my plants along with some leaf damage. Could this be a pest problem or a disease?",
126
- "Overwatering or Root Rot?": "My plant leaves are turning brown and mushy. Is this due to overwatering or a root infection?",
127
- "Poor Crop Growth": "My crops are growing very slowly and seem weak. Could this be due to soil problems or disease?",
128
- "Weather and Disease Connection": "It has been raining a lot, and now my plants have mold. Could the weather be causing a fungal disease?",
129
- "Regional Disease Concern": "I'm in a humid area and my crops often get infected. What are common diseases for this climate?",
130
- }
131
-
132
- # Dropdown menu for selecting an example
133
- selected_example = st.selectbox("Choose an example to auto-fill:", list(example_prompts.keys()))
134
-
135
- # Auto-fill the text area when an example is selected
136
- user_context = st.text_area(
137
- "Enter details, symptoms, or a question about your plant condition.",
138
- value=example_prompts[selected_example] if selected_example != "Select an example..." else "",
139
- placeholder="Example: My plant leaves are turning yellow and wilting. Is this a disease or a nutrient issue?"
140
- )
141
-
142
- # Upload an image
143
  uploaded_file = st.file_uploader("πŸ“€ Upload a plant image", type=["jpg", "jpeg", "png"])
144
 
145
  if uploaded_file:
146
- file_id = uploaded_file.name
147
-
148
- # Initialize conversation history for this image if not set
149
- if file_id not in st.session_state.conversation_history:
150
- st.session_state.conversation_history[file_id] = []
151
-
152
- # Convert file to image
153
- file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
154
- img = cv2.imdecode(file_bytes, 1)
155
-
156
- # Perform inference
157
- predicted_labels = inference(img)
158
-
159
- # Display processed image with detected diseases
160
- st.image(img, caption="πŸ” Detected Diseases", use_column_width=True)
161
-
162
-
163
- st.write(f"βœ… **High Confidence Diseases Detected:** {predicted_labels}")
164
 
165
-
166
-
167
- # AI-generated diagnosis from Gemini
 
 
 
 
 
 
168
  st.subheader("πŸ“‹ AI Diagnosis")
169
- with st.spinner("Generating diagnosis... πŸ”„"):
170
- diagnosis = generate_gemini_response(detected_disease_names, user_context, st.session_state.conversation_history[file_id])
171
-
172
- # Save response to history
173
- st.session_state.conversation_history[file_id].append({"question": user_context, "response": diagnosis})
174
-
175
- # Display the diagnosis
176
  st.write(diagnosis)
177
 
178
- # Show past conversation history
179
- if st.session_state.conversation_history[file_id]:
180
- st.subheader("πŸ—‚οΈ Conversation History")
181
- for i, entry in enumerate(st.session_state.conversation_history[file_id]):
182
- with st.expander(f"Q{i+1}: {entry['question'][:50]}..."):
183
- st.write("**User:**", entry["question"])
184
- st.write("**AI:**", entry["response"])
185
-
186
-
187
  # Instructions for users
188
  st.markdown("""
189
  ---
 
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.transforms as transforms
8
  from io import BytesIO
 
9
  import os
10
+ import google.generativeai as genai
11
+ from vit_model import VisionTransformer
 
 
 
 
 
 
 
 
 
 
12
 
13
+ # Load class names (should match training labels)
14
+ CLASS_NAMES = ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy'] # Update with actual class names
15
 
16
+ # Configure Google Gemini API
17
+ GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
18
+ genai.configure(api_key=GEMINI_API_KEY)
 
 
 
 
 
 
 
19
 
20
+ # Load the model
21
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ model = VisionTransformer(img_size=128, patch_size=8, num_classes=len(CLASS_NAMES), embed_dim=768, depth=8, num_heads=12, mlp_dim=2048, dropout=0.1)
23
+ model.load_state_dict(torch.load("custom_vit.pth", map_location=DEVICE))
24
+ model.to(DEVICE)
25
+ model.eval()
26
+
27
+ # Function to preprocess images for ViT
28
+ def preprocess_image(image, target_size=(128, 128)):
29
+ transform = transforms.Compose([
30
+ transforms.Resize(target_size),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize([0.5], [0.5]) # Normalize to match training
33
+ ])
34
+ return transform(image).unsqueeze(0).to(DEVICE)
35
+
36
+ # Function for inference using ViT
37
+ def vit_inference(image):
38
+ """Predicts crop disease using the custom ViT model."""
39
+ input_tensor = preprocess_image(image)
40
+ with torch.no_grad():
41
+ output = model(input_tensor)
42
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
43
+ predicted_class = torch.argmax(probabilities).item()
44
+ return CLASS_NAMES[predicted_class], probabilities[predicted_class].item()
45
+
46
+ # Generate response from Gemini AI
47
+ def generate_gemini_response(disease_name):
48
+ """Generate a structured diagnosis using Gemini API."""
49
  try:
50
  model = genai.GenerativeModel("gemini-1.5-pro")
 
 
51
  prompt = f"""
52
+ You are an expert plant pathologist. The detected crop disease is: {disease_name}.
53
 
54
+ Provide a structured analysis including:
55
+ - Pathogen details
56
+ - Severity level
57
+ - Symptoms
58
+ - Economic impact
59
+ - Treatment options (short-term and long-term)
60
+ - Prevention strategies
61
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  response = model.generate_content(prompt)
63
  return response.text if response else "No response from Gemini."
64
  except Exception as e:
65
  return f"Error connecting to Gemini API: {str(e)}"
66
 
67
+ # Initialize Streamlit app
68
+ st.title("🌱 AI-Powered Crop Disease Detection")
 
 
 
 
 
 
 
69
 
70
+ # Upload image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  uploaded_file = st.file_uploader("πŸ“€ Upload a plant image", type=["jpg", "jpeg", "png"])
72
 
73
  if uploaded_file:
74
+ image = Image.open(uploaded_file).convert("RGB")
75
+ st.image(image, caption="Uploaded Image", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ # Run ViT inference
78
+ with st.spinner("Analyzing with Vision Transformer... πŸ”"):
79
+ predicted_class, confidence = vit_inference(image)
80
+
81
+ st.write(f"βœ… **Detected Disease:** {predicted_class} (Confidence: {confidence:.2f})")
82
+
83
+ # Connect to Gemini AI for diagnosis
84
+ with st.spinner("Generating diagnosis with Gemini AI... πŸ’‘"):
85
+ diagnosis = generate_gemini_response(predicted_class)
86
  st.subheader("πŸ“‹ AI Diagnosis")
 
 
 
 
 
 
 
87
  st.write(diagnosis)
88
 
 
 
 
 
 
 
 
 
 
89
  # Instructions for users
90
  st.markdown("""
91
  ---