Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ import os
|
|
8 |
import logging
|
9 |
import requests
|
10 |
from io import BytesIO
|
|
|
11 |
|
12 |
# Setup logging
|
13 |
logging.basicConfig(level=logging.INFO)
|
@@ -17,7 +18,10 @@ logger = logging.getLogger(__name__)
|
|
17 |
num_classes = 3
|
18 |
|
19 |
# Confidence threshold for reliable predictions
|
20 |
-
CONFIDENCE_THRESHOLD = 0.
|
|
|
|
|
|
|
21 |
|
22 |
# Download model from Hugging Face
|
23 |
def download_model():
|
@@ -61,6 +65,16 @@ transform = transforms.Compose([
|
|
61 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
62 |
])
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
# Prediction function for an uploaded image
|
65 |
def predict_from_image_url(image_url):
|
66 |
try:
|
@@ -82,25 +96,19 @@ def predict_from_image_url(image_url):
|
|
82 |
|
83 |
# Define class information
|
84 |
class_info = {
|
85 |
-
0: {"name": "Fall Army Worm", "problem_id": "126"},
|
86 |
-
1: {"name": "Phosphorus Deficiency", "problem_id": "142"},
|
87 |
-
2: {"name": "Bacterial Leaf Blight", "problem_id": "203"}
|
88 |
-
}
|
89 |
-
|
90 |
-
# Construct probabilities dictionary
|
91 |
-
probabilities_dict = {
|
92 |
-
"Fall Army Worm": f"{probabilities[0]*100:.2f}%",
|
93 |
-
"Phosphorus Deficiency": f"{probabilities[1]*100:.2f}%",
|
94 |
-
"Bacterial Leaf Blight": f"{probabilities[2]*100:.2f}%"
|
95 |
}
|
96 |
|
97 |
# Validate predicted class index
|
98 |
if predicted_class not in class_info:
|
99 |
logger.warning(f"Unexpected class prediction: {predicted_class} for image URL: {image_url}")
|
100 |
return {
|
101 |
-
"status": "
|
102 |
-
"
|
103 |
-
"
|
|
|
104 |
}
|
105 |
|
106 |
# Get predicted class info
|
@@ -109,6 +117,24 @@ def predict_from_image_url(image_url):
|
|
109 |
problem_id = predicted_info["problem_id"]
|
110 |
confidence = probabilities[predicted_class].item() # Confidence score for the predicted class
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
# Check confidence threshold
|
113 |
if confidence < CONFIDENCE_THRESHOLD:
|
114 |
logger.warning(
|
@@ -116,29 +142,28 @@ def predict_from_image_url(image_url):
|
|
116 |
f"for image URL: {image_url}"
|
117 |
)
|
118 |
return {
|
119 |
-
"status": "
|
120 |
-
"message": (
|
121 |
-
f"Prediction confidence ({confidence*100:.2f}%) is below the threshold ({CONFIDENCE_THRESHOLD*100}%). "
|
122 |
-
"This image may not belong to the trained classes (Fall Army Worm, Phosphorus Deficiency, Bacterial Leaf Blight)."
|
123 |
-
),
|
124 |
"predicted_class": predicted_name,
|
125 |
"problem_id": problem_id,
|
126 |
-
"confidence": f"{confidence*100:.2f}%"
|
127 |
-
"probabilities": probabilities_dict
|
128 |
}
|
129 |
|
130 |
# Return successful prediction
|
131 |
return {
|
132 |
-
"status": "
|
133 |
"predicted_class": predicted_name,
|
134 |
"problem_id": problem_id,
|
135 |
-
"confidence": f"{confidence*100:.2f}%"
|
136 |
-
"probabilities": probabilities_dict
|
137 |
}
|
138 |
|
139 |
except Exception as e:
|
140 |
logger.error(f"Error processing image URL {image_url}: {str(e)}")
|
141 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
# Gradio interface
|
144 |
demo = gr.Interface(
|
|
|
8 |
import logging
|
9 |
import requests
|
10 |
from io import BytesIO
|
11 |
+
import numpy as np
|
12 |
|
13 |
# Setup logging
|
14 |
logging.basicConfig(level=logging.INFO)
|
|
|
18 |
num_classes = 3
|
19 |
|
20 |
# Confidence threshold for reliable predictions
|
21 |
+
CONFIDENCE_THRESHOLD = 0.8 # 80%
|
22 |
+
|
23 |
+
# Entropy threshold for flat probability distribution (to detect non-maize/rice images)
|
24 |
+
ENTROPY_THRESHOLD = 0.9 # Lower entropy means a more peaked distribution (more confident)
|
25 |
|
26 |
# Download model from Hugging Face
|
27 |
def download_model():
|
|
|
65 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
66 |
])
|
67 |
|
68 |
+
# Function to compute entropy of the probability distribution
|
69 |
+
def compute_entropy(probabilities):
|
70 |
+
probs = probabilities.numpy() # Convert to numpy array
|
71 |
+
# Avoid log(0) by adding a small epsilon
|
72 |
+
probs = np.clip(probs, 1e-10, 1.0)
|
73 |
+
entropy = -np.sum(probs * np.log(probs))
|
74 |
+
# Normalize entropy by the maximum possible entropy (log(num_classes))
|
75 |
+
max_entropy = np.log(num_classes)
|
76 |
+
return entropy / max_entropy
|
77 |
+
|
78 |
# Prediction function for an uploaded image
|
79 |
def predict_from_image_url(image_url):
|
80 |
try:
|
|
|
96 |
|
97 |
# Define class information
|
98 |
class_info = {
|
99 |
+
0: {"name": "Fall Army Worm", "problem_id": "126", "crop": "maize"},
|
100 |
+
1: {"name": "Phosphorus Deficiency", "problem_id": "142", "crop": "maize"},
|
101 |
+
2: {"name": "Bacterial Leaf Blight", "problem_id": "203", "crop": "rice"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
}
|
103 |
|
104 |
# Validate predicted class index
|
105 |
if predicted_class not in class_info:
|
106 |
logger.warning(f"Unexpected class prediction: {predicted_class} for image URL: {image_url}")
|
107 |
return {
|
108 |
+
"status": "invalid",
|
109 |
+
"predicted_class": None,
|
110 |
+
"problem_id": None,
|
111 |
+
"confidence": None
|
112 |
}
|
113 |
|
114 |
# Get predicted class info
|
|
|
117 |
problem_id = predicted_info["problem_id"]
|
118 |
confidence = probabilities[predicted_class].item() # Confidence score for the predicted class
|
119 |
|
120 |
+
# Compute entropy of the probability distribution
|
121 |
+
entropy = compute_entropy(probabilities)
|
122 |
+
logger.info(f"Prediction entropy: {entropy:.4f}, confidence: {confidence:.4f} for image URL: {image_url}")
|
123 |
+
|
124 |
+
# Check if the image is likely maize or rice based on entropy and confidence
|
125 |
+
# High entropy (flat distribution) suggests the image may not be maize or rice
|
126 |
+
if entropy > ENTROPY_THRESHOLD:
|
127 |
+
logger.warning(
|
128 |
+
f"High entropy ({entropy:.4f} > {ENTROPY_THRESHOLD}) for image URL: {image_url}. "
|
129 |
+
"Image may not be of maize or rice."
|
130 |
+
)
|
131 |
+
return {
|
132 |
+
"status": "invalid",
|
133 |
+
"predicted_class": predicted_name,
|
134 |
+
"problem_id": problem_id,
|
135 |
+
"confidence": f"{confidence*100:.2f}%"
|
136 |
+
}
|
137 |
+
|
138 |
# Check confidence threshold
|
139 |
if confidence < CONFIDENCE_THRESHOLD:
|
140 |
logger.warning(
|
|
|
142 |
f"for image URL: {image_url}"
|
143 |
)
|
144 |
return {
|
145 |
+
"status": "invalid",
|
|
|
|
|
|
|
|
|
146 |
"predicted_class": predicted_name,
|
147 |
"problem_id": problem_id,
|
148 |
+
"confidence": f"{confidence*100:.2f}%"
|
|
|
149 |
}
|
150 |
|
151 |
# Return successful prediction
|
152 |
return {
|
153 |
+
"status": "valid",
|
154 |
"predicted_class": predicted_name,
|
155 |
"problem_id": problem_id,
|
156 |
+
"confidence": f"{confidence*100:.2f}%"
|
|
|
157 |
}
|
158 |
|
159 |
except Exception as e:
|
160 |
logger.error(f"Error processing image URL {image_url}: {str(e)}")
|
161 |
+
return {
|
162 |
+
"status": "invalid",
|
163 |
+
"predicted_class": None,
|
164 |
+
"problem_id": None,
|
165 |
+
"confidence": None
|
166 |
+
}
|
167 |
|
168 |
# Gradio interface
|
169 |
demo = gr.Interface(
|