Sobit commited on
Commit
2f9a860
Β·
verified Β·
1 Parent(s): 045c942

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -32
app.py CHANGED
@@ -3,27 +3,31 @@ import numpy as np
3
  import cv2
4
  from PIL import Image
5
  from io import BytesIO
6
- from ultralytics import YOLO
7
  import os
8
  import tempfile
9
  import base64
10
  import requests
11
  from datetime import datetime
12
  import google.generativeai as genai # Import Gemini API
 
 
13
 
 
 
14
  # Configuring Google Gemini API
15
  GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
16
  genai.configure(api_key=GEMINI_API_KEY)
17
 
18
  # Loading YOLO model for crop disease detection
19
- yolo_model = YOLO("models/best.pt")
20
 
21
  # Initializing conversation history if not set
22
  if "conversation_history" not in st.session_state:
23
  st.session_state.conversation_history = {}
24
 
25
  # Function to preprocess images
26
- def preprocess_image(image, target_size=(224, 224)):
27
  """Resize image for AI models."""
28
  image = Image.fromarray(image)
29
  image = image.resize(target_size)
@@ -81,24 +85,12 @@ def generate_gemini_response(disease_list, user_context="", conversation_history
81
  return f"Error connecting to Gemini API: {str(e)}"
82
 
83
  # Performing inference using YOLO
84
- def inference(image, conf_threshold=0.5):
85
  """Detect crop diseases in the given image with confidence filtering."""
86
- results = yolo_model(image, conf=0.4) # Adjusted confidence threshold for detection
87
- infer = np.zeros(image.shape, dtype=np.uint8)
88
- detected_classes = []
89
- class_names = {}
90
- confidence_scores = []
91
 
92
- for r in results:
93
- infer = r.plot()
94
- class_names = r.names
95
- for i, cls in enumerate(r.boxes.cls.tolist()):
96
- confidence = r.boxes.conf[i].item() # Get confidence score
97
- if confidence >= conf_threshold: # Only consider high-confidence detections
98
- detected_classes.append(cls)
99
- confidence_scores.append(confidence)
100
-
101
- return infer, detected_classes, class_names, confidence_scores
102
 
103
 
104
 
@@ -163,21 +155,13 @@ if uploaded_file:
163
  img = cv2.imdecode(file_bytes, 1)
164
 
165
  # Perform inference
166
- processed_image, detected_classes, class_names, confidence_scores = inference(img)
167
 
168
  # Display processed image with detected diseases
169
- st.image(processed_image, caption="πŸ” Detected Diseases", use_column_width=True)
170
-
171
- if detected_classes:
172
- # Convert detected class indexes to names
173
- detected_disease_names = [
174
- f"{class_names[cls]} ({confidence_scores[i]:.2f})"
175
- for i, cls in enumerate(detected_classes)
176
- ]
177
-
178
- # Show only the most confident detections
179
- if detected_disease_names:
180
- st.write(f"βœ… **High Confidence Diseases Detected:** {', '.join(detected_disease_names)}")
181
 
182
 
183
 
 
3
  import cv2
4
  from PIL import Image
5
  from io import BytesIO
6
+ #from ultralytics import YOLO
7
  import os
8
  import tempfile
9
  import base64
10
  import requests
11
  from datetime import datetime
12
  import google.generativeai as genai # Import Gemini API
13
+ from tensorflow.keras.models import load_model
14
+ from vit import vit_classifier
15
 
16
+ # Load the model
17
+ model = load_model('vit_updated.weights.h5') # Replace with your model file path
18
  # Configuring Google Gemini API
19
  GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
20
  genai.configure(api_key=GEMINI_API_KEY)
21
 
22
  # Loading YOLO model for crop disease detection
23
+ #yolo_model = YOLO("models/best.pt")
24
 
25
  # Initializing conversation history if not set
26
  if "conversation_history" not in st.session_state:
27
  st.session_state.conversation_history = {}
28
 
29
  # Function to preprocess images
30
+ def preprocess_image(image, target_size=(256, 256)):
31
  """Resize image for AI models."""
32
  image = Image.fromarray(image)
33
  image = image.resize(target_size)
 
85
  return f"Error connecting to Gemini API: {str(e)}"
86
 
87
  # Performing inference using YOLO
88
+ def inference(image):
89
  """Detect crop diseases in the given image with confidence filtering."""
90
+ predictions = vit_classifier.predict(image)
91
+ predicted_labels = np.argmax(predictions, axis=1)
 
 
 
92
 
93
+ return predicted_labels
 
 
 
 
 
 
 
 
 
94
 
95
 
96
 
 
155
  img = cv2.imdecode(file_bytes, 1)
156
 
157
  # Perform inference
158
+ predicted_labels = inference(img)
159
 
160
  # Display processed image with detected diseases
161
+ st.image(img, caption="πŸ” Detected Diseases", use_column_width=True)
162
+
163
+
164
+ st.write(f"βœ… **High Confidence Diseases Detected:** {predicted_labels)}")
 
 
 
 
 
 
 
 
165
 
166
 
167