jays009 commited on
Commit
5aef664
·
verified ·
1 Parent(s): e4b1309

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -37
app.py CHANGED
@@ -17,19 +17,16 @@ logger = logging.getLogger(__name__)
17
  # Define the number of classes
18
  num_classes = 3
19
 
20
- # Confidence threshold for reliable predictions
21
  CONFIDENCE_THRESHOLD = 0.8 # 80%
22
 
23
- # Entropy threshold for flat probability distribution (to detect non-maize/rice images)
24
- ENTROPY_THRESHOLD = 0.9 # Lower entropy means a more peaked distribution (more confident)
25
-
26
  # Download model from Hugging Face
27
  def download_model():
28
  model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
29
  return model_path
30
 
31
- # Load the model from Hugging Face
32
- def load_model(model_path):
33
  model = models.resnet50(pretrained=False)
34
  num_features = model.fc.in_features
35
  model.fc = nn.Sequential(
@@ -55,7 +52,7 @@ def load_model(model_path):
55
 
56
  # Path to your model
57
  model_path = download_model()
58
- model = load_model(model_path)
59
 
60
  # Define the transformation for the input image
61
  transform = transforms.Compose([
@@ -65,16 +62,6 @@ transform = transforms.Compose([
65
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
66
  ])
67
 
68
- # Function to compute entropy of the probability distribution
69
- def compute_entropy(probabilities):
70
- probs = probabilities.numpy() # Convert to numpy array
71
- # Avoid log(0) by adding a small epsilon
72
- probs = np.clip(probs, 1e-10, 1.0)
73
- entropy = -np.sum(probs * np.log(probs))
74
- # Normalize entropy by the maximum possible entropy (log(num_classes))
75
- max_entropy = np.log(num_classes)
76
- return entropy / max_entropy
77
-
78
  # Prediction function for an uploaded image
79
  def predict_from_image_url(image_url):
80
  try:
@@ -89,9 +76,13 @@ def predict_from_image_url(image_url):
89
 
90
  # Perform prediction
91
  with torch.no_grad():
92
- outputs = model(image_tensor) # Shape: [1, 3]
93
  logger.info(f"Model output shape: {outputs.shape}")
 
 
94
  probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
 
 
95
  predicted_class = torch.argmax(outputs, dim=1).item()
96
 
97
  # Define class information
@@ -115,25 +106,7 @@ def predict_from_image_url(image_url):
115
  predicted_info = class_info[predicted_class]
116
  predicted_name = predicted_info["name"]
117
  problem_id = predicted_info["problem_id"]
118
- confidence = probabilities[predicted_class].item() # Confidence score for the predicted class
119
-
120
- # Compute entropy of the probability distribution
121
- entropy = compute_entropy(probabilities)
122
- logger.info(f"Prediction entropy: {entropy:.4f}, confidence: {confidence:.4f} for image URL: {image_url}")
123
-
124
- # Check if the image is likely maize or rice based on entropy and confidence
125
- # High entropy (flat distribution) suggests the image may not be maize or rice
126
- if entropy > ENTROPY_THRESHOLD:
127
- logger.warning(
128
- f"High entropy ({entropy:.4f} > {ENTROPY_THRESHOLD}) for image URL: {image_url}. "
129
- "Image may not be of maize or rice."
130
- )
131
- return {
132
- "status": "invalid",
133
- "predicted_class": predicted_name,
134
- "problem_id": problem_id,
135
- "confidence": f"{confidence*100:.2f}%"
136
- }
137
 
138
  # Check confidence threshold
139
  if confidence < CONFIDENCE_THRESHOLD:
 
17
  # Define the number of classes
18
  num_classes = 3
19
 
20
+ # Confidence threshold for main model predictions
21
  CONFIDENCE_THRESHOLD = 0.8 # 80%
22
 
 
 
 
23
  # Download model from Hugging Face
24
  def download_model():
25
  model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
26
  return model_path
27
 
28
+ # Load the main model from Hugging Face
29
+ def load_main_model(model_path):
30
  model = models.resnet50(pretrained=False)
31
  num_features = model.fc.in_features
32
  model.fc = nn.Sequential(
 
52
 
53
  # Path to your model
54
  model_path = download_model()
55
+ main_model = load_main_model(model_path)
56
 
57
  # Define the transformation for the input image
58
  transform = transforms.Compose([
 
62
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
63
  ])
64
 
 
 
 
 
 
 
 
 
 
 
65
  # Prediction function for an uploaded image
66
  def predict_from_image_url(image_url):
67
  try:
 
76
 
77
  # Perform prediction
78
  with torch.no_grad():
79
+ outputs = main_model(image_tensor) # Shape: [1, 3]
80
  logger.info(f"Model output shape: {outputs.shape}")
81
+ # Log raw logits
82
+ logger.info(f"Raw logits: {outputs[0].numpy()}")
83
  probabilities = torch.softmax(outputs, dim=1)[0] # Convert to probabilities
84
+ # Log softmax probabilities
85
+ logger.info(f"Softmax probabilities: {probabilities.numpy()}")
86
  predicted_class = torch.argmax(outputs, dim=1).item()
87
 
88
  # Define class information
 
106
  predicted_info = class_info[predicted_class]
107
  predicted_name = predicted_info["name"]
108
  problem_id = predicted_info["problem_id"]
109
+ confidence = probabilities[predicted_class].item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  # Check confidence threshold
112
  if confidence < CONFIDENCE_THRESHOLD: