jays009 commited on
Commit
694c989
·
verified ·
1 Parent(s): d2df5a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -30
app.py CHANGED
@@ -11,10 +11,14 @@ from io import BytesIO
11
 
12
  # Setup logging
13
  logging.basicConfig(level=logging.INFO)
 
14
 
15
  # Define the number of classes
16
  num_classes = 3
17
 
 
 
 
18
  # Download model from Hugging Face
19
  def download_model():
20
  model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
@@ -26,7 +30,7 @@ def load_model(model_path):
26
  num_features = model.fc.in_features
27
  model.fc = nn.Sequential(
28
  nn.Dropout(0.5),
29
- nn.Linear(num_features, 3) # 3 classes
30
  )
31
 
32
  # Load the checkpoint
@@ -67,48 +71,74 @@ def predict_from_image_url(image_url):
67
 
68
  # Apply transformations
69
  image_tensor = transform(image).unsqueeze(0) # Shape: [1, 3, 224, 224]
70
- print(f"Input image tensor shape: {image_tensor.shape}") # Debug: Should be [1, 3, 224, 224]
71
 
72
  # Perform prediction
73
  with torch.no_grad():
74
  outputs = model(image_tensor) # Shape: [1, 3]
75
- print(f"Model output shape: {outputs.shape}") # Debug: Should be [1, 3]
76
  probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
77
  predicted_class = torch.argmax(outputs, dim=1).item()
78
 
79
- # Interpret the result
80
- if predicted_class == 0:
81
- return {
82
- "result": "The photo is of Fall Army Worm with problem ID 126.",
83
- "probabilities": {
84
- "Fall Army Worm": f"{probabilities[0]*100:.2f}%",
85
- "Phosphorus Deficiency": f"{probabilities[1]*100:.2f}%",
86
- "Bacterial Leaf Blight": f"{probabilities[2]*100:.2f}%"
87
- }
88
- }
89
- elif predicted_class == 1:
 
 
 
 
 
 
90
  return {
91
- "result": "The photo shows symptoms of Phosphorus Deficiency with Problem ID 142.",
92
- "probabilities": {
93
- "Fall Army Worm": f"{probabilities[0]*100:.2f}%",
94
- "Phosphorus Deficiency": f"{probabilities[1]*100:.2f}%",
95
- "Bacterial Leaf Blight": f"{probabilities[2]*100:.2f}%"
96
- }
97
  }
98
- elif predicted_class == 2:
 
 
 
 
 
 
 
 
 
 
 
 
99
  return {
100
- "result": "The photo shows symptoms of Bacterial Leaf Blight with Problem ID 203.",
101
- "probabilities": {
102
- "Fall Army Worm": f"{probabilities[0]*100:.2f}%",
103
- "Phosphorus Deficiency": f"{probabilities[1]*100:.2f}%",
104
- "Bacterial Leaf Blight": f"{probabilities[2]*100:.2f}%"
105
- }
 
 
 
106
  }
107
- else:
108
- return {"error": "Unexpected class prediction."}
 
 
 
 
 
 
 
109
 
110
  except Exception as e:
111
- return {"error": str(e)}
 
112
 
113
  # Gradio interface
114
  demo = gr.Interface(
 
11
 
12
  # Setup logging
13
  logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
 
16
  # Define the number of classes
17
  num_classes = 3
18
 
19
+ # Confidence threshold for reliable predictions
20
+ CONFIDENCE_THRESHOLD = 0.5 # 50%
21
+
22
  # Download model from Hugging Face
23
  def download_model():
24
  model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
 
30
  num_features = model.fc.in_features
31
  model.fc = nn.Sequential(
32
  nn.Dropout(0.5),
33
+ nn.Linear(num_features, num_classes) # 3 classes
34
  )
35
 
36
  # Load the checkpoint
 
71
 
72
  # Apply transformations
73
  image_tensor = transform(image).unsqueeze(0) # Shape: [1, 3, 224, 224]
74
+ logger.info(f"Input image tensor shape: {image_tensor.shape}")
75
 
76
  # Perform prediction
77
  with torch.no_grad():
78
  outputs = model(image_tensor) # Shape: [1, 3]
79
+ logger.info(f"Model output shape: {outputs.shape}")
80
  probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
81
  predicted_class = torch.argmax(outputs, dim=1).item()
82
 
83
+ # Define class information
84
+ class_info = {
85
+ 0: {"name": "Fall Army Worm", "problem_id": "126"},
86
+ 1: {"name": "Phosphorus Deficiency", "problem_id": "142"},
87
+ 2: {"name": "Bacterial Leaf Blight", "problem_id": "203"}
88
+ }
89
+
90
+ # Construct probabilities dictionary
91
+ probabilities_dict = {
92
+ "Fall Army Worm": f"{probabilities[0]*100:.2f}%",
93
+ "Phosphorus Deficiency": f"{probabilities[1]*100:.2f}%",
94
+ "Bacterial Leaf Blight": f"{probabilities[2]*100:.2f}%"
95
+ }
96
+
97
+ # Validate predicted class index
98
+ if predicted_class not in class_info:
99
+ logger.warning(f"Unexpected class prediction: {predicted_class} for image URL: {image_url}")
100
  return {
101
+ "status": "error",
102
+ "message": f"Unexpected class prediction (index {predicted_class}). Model may not be configured correctly.",
103
+ "probabilities": probabilities_dict
 
 
 
104
  }
105
+
106
+ # Get predicted class info
107
+ predicted_info = class_info[predicted_class]
108
+ predicted_name = predicted_info["name"]
109
+ problem_id = predicted_info["problem_id"]
110
+ confidence = probabilities[predicted_class].item() # Confidence score for the predicted class
111
+
112
+ # Check confidence threshold
113
+ if confidence < CONFIDENCE_THRESHOLD:
114
+ logger.warning(
115
+ f"Low confidence prediction: {predicted_name} with confidence {confidence*100:.2f}% "
116
+ f"for image URL: {image_url}"
117
+ )
118
  return {
119
+ "status": "uncertain",
120
+ "message": (
121
+ f"Prediction confidence ({confidence*100:.2f}%) is below the threshold ({CONFIDENCE_THRESHOLD*100}%). "
122
+ "This image may not belong to the trained classes (Fall Army Worm, Phosphorus Deficiency, Bacterial Leaf Blight)."
123
+ ),
124
+ "predicted_class": predicted_name,
125
+ "problem_id": problem_id,
126
+ "confidence": f"{confidence*100:.2f}%",
127
+ "probabilities": probabilities_dict
128
  }
129
+
130
+ # Return successful prediction
131
+ return {
132
+ "status": "success",
133
+ "predicted_class": predicted_name,
134
+ "problem_id": problem_id,
135
+ "confidence": f"{confidence*100:.2f}%",
136
+ "probabilities": probabilities_dict
137
+ }
138
 
139
  except Exception as e:
140
+ logger.error(f"Error processing image URL {image_url}: {str(e)}")
141
+ return {"status": "error", "message": str(e)}
142
 
143
  # Gradio interface
144
  demo = gr.Interface(