trapezius60 commited on
Commit
285ba5e
·
verified ·
1 Parent(s): c295347

Update app.py

Browse files

version 7/8/2025


# ✅ Training Script Summary:
# - Use torchvision.datasets.ImageFolder with folders "Edible/" and "Poisonous/"
# - Apply data augmentation (flip, jitter, rotate, etc.)
# - Split into train/val and load with DataLoader
# - Use pretrained ResNet50 with replaced final FC layer
# - Train using CrossEntropyLoss and Adam optimizer
# - Save model using torch.save(...)

Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -2,15 +2,14 @@
2
  from PIL import Image
3
  import torch
4
  import torchvision.transforms as transforms
5
- from torchvision import models
6
- import gradio as gr
7
- import webbrowser
8
- from rembg import remove # 🆕 Background removal
9
 
10
  # 🔧 Set device
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
- # 📦 Load the trained model
14
  model = models.resnet50(pretrained=False)
15
  model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
16
  model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
@@ -32,13 +31,28 @@ transform = transforms.Compose([
32
  [0.229, 0.224, 0.225])
33
  ])
34
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # 🧠 Classification function with validation
36
  CONFIDENCE_THRESHOLD = 85.0 # Minimum confidence considered safe enough to show suggestion
37
 
38
  def classify_mushroom(image: Image.Image):
39
  try:
40
- image = image.convert("RGBA") # Required format for rembg
41
- image_no_bg = remove(image).convert("RGB") # 🆕 Remove background
 
 
 
42
  tensor = transform(image_no_bg).unsqueeze(0).to(device)
43
 
44
  with torch.no_grad():
@@ -49,7 +63,6 @@ def classify_mushroom(image: Image.Image):
49
 
50
  suggestion = mushroom_species[label] if score >= CONFIDENCE_THRESHOLD else "Confidence too low to suggest species."
51
 
52
- # Check confidence and warn
53
  if score < CONFIDENCE_THRESHOLD:
54
  label = "Uncertain"
55
 
@@ -59,9 +72,6 @@ def classify_mushroom(image: Image.Image):
59
  print(f"❌ Error: {e}")
60
  return "Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo."
61
 
62
- # 🔗 Open user manual (link version) 🆕
63
- MANUAL_URL = "https://drive.google.com/drive/folders/19lUCEaLstrRjCzqpDlWErhRd1EXWUGbf?usp=sharing"
64
- manual_link = f"<a href=\"{MANUAL_URL}\" target=\"_blank\">📄 Open User Manual</a>"
65
  # 🎛️ Gradio UI
66
  if __name__ == "__main__":
67
  with gr.Blocks() as demo:
@@ -69,7 +79,7 @@ if __name__ == "__main__":
69
  gr.Markdown("Upload a mushroom photo or use your camera to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ")
70
 
71
  with gr.Row():
72
- image_input = gr.Image(type="pil", label="📷 Upload or Capture Mushroom Image") # 🆕 safer source usage
73
  with gr.Column():
74
  label_en = gr.Textbox(label="🧬 Prediction (English)")
75
  label_th = gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)")
@@ -85,7 +95,8 @@ if __name__ == "__main__":
85
  )
86
 
87
  gr.Markdown("---")
88
- gr.Markdown(manual_link) # 🆕 Display clickable user manual
89
- gr.Markdown("App version: 1.0.2 | Updated: August 2025") # 🆕 Version bump
90
 
91
  demo.launch()
 
 
 
2
  from PIL import Image
3
  import torch
4
  import torchvision.transforms as transforms
5
+ from torchvision import models\import gradio as gr
6
+ from rembg import remove # Background removal
7
+ from transformers import pipeline # For non-mushroom detection
 
8
 
9
  # 🔧 Set device
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ # 📦 Load your fine-tuned model
13
  model = models.resnet50(pretrained=False)
14
  model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
15
  model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
 
31
  [0.229, 0.224, 0.225])
32
  ])
33
 
34
+ # 🔍 Pretrained image classifier for screening non-mushroom
35
+ label_detector = pipeline("image-classification", model="microsoft/resnet-50")
36
+ mushroom_keywords = ["mushroom", "agaric", "amanita", "fungus", "earthstar", "toadstool"]
37
+
38
+ def is_mushroom_image(image):
39
+ try:
40
+ result = label_detector(image)
41
+ top_label = result[0]["label"].lower()
42
+ return any(keyword in top_label for keyword in mushroom_keywords)
43
+ except:
44
+ return False
45
+
46
  # 🧠 Classification function with validation
47
  CONFIDENCE_THRESHOLD = 85.0 # Minimum confidence considered safe enough to show suggestion
48
 
49
  def classify_mushroom(image: Image.Image):
50
  try:
51
+ if not is_mushroom_image(image):
52
+ return "Not a mushroom", "ไม่ใช่เห็ด", "0.00%", "Please upload a mushroom photo."
53
+
54
+ image = image.convert("RGBA")
55
+ image_no_bg = remove(image).convert("RGB")
56
  tensor = transform(image_no_bg).unsqueeze(0).to(device)
57
 
58
  with torch.no_grad():
 
63
 
64
  suggestion = mushroom_species[label] if score >= CONFIDENCE_THRESHOLD else "Confidence too low to suggest species."
65
 
 
66
  if score < CONFIDENCE_THRESHOLD:
67
  label = "Uncertain"
68
 
 
72
  print(f"❌ Error: {e}")
73
  return "Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo."
74
 
 
 
 
75
  # 🎛️ Gradio UI
76
  if __name__ == "__main__":
77
  with gr.Blocks() as demo:
 
79
  gr.Markdown("Upload a mushroom photo or use your camera to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ")
80
 
81
  with gr.Row():
82
+ image_input = gr.Image(type="pil", label="📷 Upload or Capture Mushroom Image")
83
  with gr.Column():
84
  label_en = gr.Textbox(label="🧬 Prediction (English)")
85
  label_th = gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)")
 
95
  )
96
 
97
  gr.Markdown("---")
98
+ gr.Markdown("App version: 1.1.0 | Updated: August 2025")
 
99
 
100
  demo.launch()
101
+
102
+