jays009 commited on
Commit
e4b1309
·
verified ·
1 Parent(s): 694c989

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -25
app.py CHANGED
@@ -8,6 +8,7 @@ import os
8
  import logging
9
  import requests
10
  from io import BytesIO
 
11
 
12
  # Setup logging
13
  logging.basicConfig(level=logging.INFO)
@@ -17,7 +18,10 @@ logger = logging.getLogger(__name__)
17
  num_classes = 3
18
 
19
  # Confidence threshold for reliable predictions
20
- CONFIDENCE_THRESHOLD = 0.5 # 50%
 
 
 
21
 
22
  # Download model from Hugging Face
23
  def download_model():
@@ -61,6 +65,16 @@ transform = transforms.Compose([
61
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
62
  ])
63
 
 
 
 
 
 
 
 
 
 
 
64
  # Prediction function for an uploaded image
65
  def predict_from_image_url(image_url):
66
  try:
@@ -82,25 +96,19 @@ def predict_from_image_url(image_url):
82
 
83
  # Define class information
84
  class_info = {
85
- 0: {"name": "Fall Army Worm", "problem_id": "126"},
86
- 1: {"name": "Phosphorus Deficiency", "problem_id": "142"},
87
- 2: {"name": "Bacterial Leaf Blight", "problem_id": "203"}
88
- }
89
-
90
- # Construct probabilities dictionary
91
- probabilities_dict = {
92
- "Fall Army Worm": f"{probabilities[0]*100:.2f}%",
93
- "Phosphorus Deficiency": f"{probabilities[1]*100:.2f}%",
94
- "Bacterial Leaf Blight": f"{probabilities[2]*100:.2f}%"
95
  }
96
 
97
  # Validate predicted class index
98
  if predicted_class not in class_info:
99
  logger.warning(f"Unexpected class prediction: {predicted_class} for image URL: {image_url}")
100
  return {
101
- "status": "error",
102
- "message": f"Unexpected class prediction (index {predicted_class}). Model may not be configured correctly.",
103
- "probabilities": probabilities_dict
 
104
  }
105
 
106
  # Get predicted class info
@@ -109,6 +117,24 @@ def predict_from_image_url(image_url):
109
  problem_id = predicted_info["problem_id"]
110
  confidence = probabilities[predicted_class].item() # Confidence score for the predicted class
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  # Check confidence threshold
113
  if confidence < CONFIDENCE_THRESHOLD:
114
  logger.warning(
@@ -116,29 +142,28 @@ def predict_from_image_url(image_url):
116
  f"for image URL: {image_url}"
117
  )
118
  return {
119
- "status": "uncertain",
120
- "message": (
121
- f"Prediction confidence ({confidence*100:.2f}%) is below the threshold ({CONFIDENCE_THRESHOLD*100}%). "
122
- "This image may not belong to the trained classes (Fall Army Worm, Phosphorus Deficiency, Bacterial Leaf Blight)."
123
- ),
124
  "predicted_class": predicted_name,
125
  "problem_id": problem_id,
126
- "confidence": f"{confidence*100:.2f}%",
127
- "probabilities": probabilities_dict
128
  }
129
 
130
  # Return successful prediction
131
  return {
132
- "status": "success",
133
  "predicted_class": predicted_name,
134
  "problem_id": problem_id,
135
- "confidence": f"{confidence*100:.2f}%",
136
- "probabilities": probabilities_dict
137
  }
138
 
139
  except Exception as e:
140
  logger.error(f"Error processing image URL {image_url}: {str(e)}")
141
- return {"status": "error", "message": str(e)}
 
 
 
 
 
142
 
143
  # Gradio interface
144
  demo = gr.Interface(
 
8
  import logging
9
  import requests
10
  from io import BytesIO
11
+ import numpy as np
12
 
13
  # Setup logging
14
  logging.basicConfig(level=logging.INFO)
 
18
  num_classes = 3
19
 
20
  # Confidence threshold for reliable predictions
21
+ CONFIDENCE_THRESHOLD = 0.8 # 80%
22
+
23
+ # Entropy threshold for flat probability distribution (to detect non-maize/rice images)
24
+ ENTROPY_THRESHOLD = 0.9 # Lower entropy means a more peaked distribution (more confident)
25
 
26
  # Download model from Hugging Face
27
  def download_model():
 
65
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
66
  ])
67
 
68
+ # Function to compute entropy of the probability distribution
69
+ def compute_entropy(probabilities):
70
+ probs = probabilities.numpy() # Convert to numpy array
71
+ # Avoid log(0) by adding a small epsilon
72
+ probs = np.clip(probs, 1e-10, 1.0)
73
+ entropy = -np.sum(probs * np.log(probs))
74
+ # Normalize entropy by the maximum possible entropy (log(num_classes))
75
+ max_entropy = np.log(num_classes)
76
+ return entropy / max_entropy
77
+
78
  # Prediction function for an uploaded image
79
  def predict_from_image_url(image_url):
80
  try:
 
96
 
97
  # Define class information
98
  class_info = {
99
+ 0: {"name": "Fall Army Worm", "problem_id": "126", "crop": "maize"},
100
+ 1: {"name": "Phosphorus Deficiency", "problem_id": "142", "crop": "maize"},
101
+ 2: {"name": "Bacterial Leaf Blight", "problem_id": "203", "crop": "rice"}
 
 
 
 
 
 
 
102
  }
103
 
104
  # Validate predicted class index
105
  if predicted_class not in class_info:
106
  logger.warning(f"Unexpected class prediction: {predicted_class} for image URL: {image_url}")
107
  return {
108
+ "status": "invalid",
109
+ "predicted_class": None,
110
+ "problem_id": None,
111
+ "confidence": None
112
  }
113
 
114
  # Get predicted class info
 
117
  problem_id = predicted_info["problem_id"]
118
  confidence = probabilities[predicted_class].item() # Confidence score for the predicted class
119
 
120
+ # Compute entropy of the probability distribution
121
+ entropy = compute_entropy(probabilities)
122
+ logger.info(f"Prediction entropy: {entropy:.4f}, confidence: {confidence:.4f} for image URL: {image_url}")
123
+
124
+ # Check if the image is likely maize or rice based on entropy and confidence
125
+ # High entropy (flat distribution) suggests the image may not be maize or rice
126
+ if entropy > ENTROPY_THRESHOLD:
127
+ logger.warning(
128
+ f"High entropy ({entropy:.4f} > {ENTROPY_THRESHOLD}) for image URL: {image_url}. "
129
+ "Image may not be of maize or rice."
130
+ )
131
+ return {
132
+ "status": "invalid",
133
+ "predicted_class": predicted_name,
134
+ "problem_id": problem_id,
135
+ "confidence": f"{confidence*100:.2f}%"
136
+ }
137
+
138
  # Check confidence threshold
139
  if confidence < CONFIDENCE_THRESHOLD:
140
  logger.warning(
 
142
  f"for image URL: {image_url}"
143
  )
144
  return {
145
+ "status": "invalid",
 
 
 
 
146
  "predicted_class": predicted_name,
147
  "problem_id": problem_id,
148
+ "confidence": f"{confidence*100:.2f}%"
 
149
  }
150
 
151
  # Return successful prediction
152
  return {
153
+ "status": "valid",
154
  "predicted_class": predicted_name,
155
  "problem_id": problem_id,
156
+ "confidence": f"{confidence*100:.2f}%"
 
157
  }
158
 
159
  except Exception as e:
160
  logger.error(f"Error processing image URL {image_url}: {str(e)}")
161
+ return {
162
+ "status": "invalid",
163
+ "predicted_class": None,
164
+ "problem_id": None,
165
+ "confidence": None
166
+ }
167
 
168
  # Gradio interface
169
  demo = gr.Interface(