trapezius60 commited on
Commit
bf3c131
·
verified ·
1 Parent(s): ef934c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -28
app.py CHANGED
@@ -1,32 +1,29 @@
 
1
  from PIL import Image
2
  import torch
3
  import torchvision.transforms as transforms
4
  from torchvision import models
5
  import gradio as gr
6
- from rembg import remove # 🆕 Background removal
7
- import io
8
- import datetime
9
 
10
  # 🔧 Set device
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
- # 📦 Load your fine-tuned model
14
  model = models.resnet50(pretrained=False)
15
  model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
16
  model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
17
  model = model.to(device)
18
  model.eval()
19
 
20
- # 🏷️ Class names
21
  class_names = ['Edible', 'Poisonous']
22
-
23
- # 🍄 Mapping for more detailed species
24
  mushroom_species = {
25
  "Edible": "Possible species:\n• Amanita citrina\n• Russula delica\n• Phaeogyroporus portentosus",
26
  "Poisonous": "Possible species:\n• Amanita phalloides\n• Inocybe rimosa"
27
  }
28
 
29
- # 🎨 Image preprocessing (must match training)
30
  transform = transforms.Compose([
31
  transforms.Resize((224, 224)),
32
  transforms.ToTensor(),
@@ -34,16 +31,13 @@ transform = transforms.Compose([
34
  [0.229, 0.224, 0.225])
35
  ])
36
 
37
- # 🧠 Prediction function
 
 
38
  def classify_mushroom(image: Image.Image):
39
  try:
40
  image = image.convert("RGB")
41
-
42
- # 🆕 Remove background
43
- bg_removed = remove(image)
44
-
45
- # 🆗 Prepare cleaned image for model
46
- tensor = transform(bg_removed).unsqueeze(0).to(device)
47
 
48
  with torch.no_grad():
49
  outputs = model(tensor)
@@ -51,42 +45,54 @@ def classify_mushroom(image: Image.Image):
51
  label = class_names[predicted.item()]
52
  score = torch.softmax(outputs, dim=1)[0][predicted.item()].item() * 100
53
 
54
- suggestion = mushroom_species[label]
55
 
56
- return label, "กินได้" if label == "Edible" else "พิษ", f"{score:.2f}%", suggestion, bg_removed
 
 
 
 
57
 
58
  except Exception as e:
59
  print(f"❌ Error: {e}")
60
- return "Error", "ผิดพลาด", "N/A", "N/A", image
61
 
62
- # 📆 Version Info
63
- version_info = "**Version:** 1.0.1 \n**Last updated:** " + datetime.datetime.now().strftime("%Y-%m-%d")
 
 
 
64
 
65
  # 🎛️ Gradio UI
66
  if __name__ == "__main__":
67
  with gr.Blocks() as demo:
68
  gr.Markdown("## 🍄 Mushroom Safety Classifier")
69
- gr.Markdown("Upload a mushroom photo to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดเพื่อทำนายว่าเห็ดกินได้หรือมีพิษ")
70
 
71
  with gr.Row():
 
72
  with gr.Column():
73
- image_input = gr.Image(type="pil", label="📷 Upload Mushroom Image")
74
- cleaned_image = gr.Image(label="🧼 Background Removed Image") # 🆕 Show cleaned image
75
- with gr.Column():
76
- label_en = gr.Textbox(label="🧠 Prediction (English)")
77
- label_th = gr.Textbox(label="🗣️ คำทำนาย (ภาษาไทย)")
78
  confidence = gr.Textbox(label="📶 Confidence Score")
79
  label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")
80
 
81
  classify_btn = gr.Button("🔍 Classify")
 
82
 
83
  classify_btn.click(
84
  fn=classify_mushroom,
85
  inputs=image_input,
86
- outputs=[label_en, label_th, confidence, label_hint, cleaned_image]
 
 
 
 
 
 
87
  )
88
 
89
  gr.Markdown("---")
90
- gr.Markdown(version_info) # 📅 Version and update info
91
 
92
  demo.launch()
 
1
+ # ✅ Import required libraries
2
  from PIL import Image
3
  import torch
4
  import torchvision.transforms as transforms
5
  from torchvision import models
6
  import gradio as gr
7
+ import webbrowser
 
 
8
 
9
  # 🔧 Set device
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ # 📦 Load the trained model
13
  model = models.resnet50(pretrained=False)
14
  model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
15
  model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
16
  model = model.to(device)
17
  model.eval()
18
 
19
+ # 🏷️ Class names and species mapping
20
  class_names = ['Edible', 'Poisonous']
 
 
21
  mushroom_species = {
22
  "Edible": "Possible species:\n• Amanita citrina\n• Russula delica\n• Phaeogyroporus portentosus",
23
  "Poisonous": "Possible species:\n• Amanita phalloides\n• Inocybe rimosa"
24
  }
25
 
26
+ # 🎨 Image preprocessing
27
  transform = transforms.Compose([
28
  transforms.Resize((224, 224)),
29
  transforms.ToTensor(),
 
31
  [0.229, 0.224, 0.225])
32
  ])
33
 
34
+ # 🧠 Classification function with validation
35
+ CONFIDENCE_THRESHOLD = 85.0 # Minimum confidence considered safe enough to show suggestion
36
+
37
  def classify_mushroom(image: Image.Image):
38
  try:
39
  image = image.convert("RGB")
40
+ tensor = transform(image).unsqueeze(0).to(device)
 
 
 
 
 
41
 
42
  with torch.no_grad():
43
  outputs = model(tensor)
 
45
  label = class_names[predicted.item()]
46
  score = torch.softmax(outputs, dim=1)[0][predicted.item()].item() * 100
47
 
48
+ suggestion = mushroom_species[label] if score >= CONFIDENCE_THRESHOLD else "Confidence too low to suggest species."
49
 
50
+ # Check confidence and warn
51
+ if score < CONFIDENCE_THRESHOLD:
52
+ label = "Uncertain"
53
+
54
+ return label, ("กินได้" if label == "Edible" else ("พิษ" if label == "Poisonous" else "ไม่มั่นใจ")), f"{score:.2f}%", suggestion
55
 
56
  except Exception as e:
57
  print(f"❌ Error: {e}")
58
+ return "Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo."
59
 
60
+ # 🔗 Open user manual
61
+ MANUAL_URL = "https://drive.google.com/drive/folders/19lUCEaLstrRjCzqpDlWErhRd1EXWUGbf?usp=sharing"
62
+ def open_manual():
63
+ webbrowser.open(MANUAL_URL)
64
+ return "", "", "", "Opening user manual..."
65
 
66
  # 🎛️ Gradio UI
67
  if __name__ == "__main__":
68
  with gr.Blocks() as demo:
69
  gr.Markdown("## 🍄 Mushroom Safety Classifier")
70
+ gr.Markdown("Upload a mushroom photo or use your camera to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ")
71
 
72
  with gr.Row():
73
+ image_input = gr.Image(type="pil", label="📷 Upload or Capture Mushroom Image", source="upload", tool="editor")
74
  with gr.Column():
75
+ label_en = gr.Textbox(label="🧬 Prediction (English)")
76
+ label_th = gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)")
 
 
 
77
  confidence = gr.Textbox(label="📶 Confidence Score")
78
  label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")
79
 
80
  classify_btn = gr.Button("🔍 Classify")
81
+ manual_btn = gr.Button("Open User Manual 📄")
82
 
83
  classify_btn.click(
84
  fn=classify_mushroom,
85
  inputs=image_input,
86
+ outputs=[label_en, label_th, confidence, label_hint]
87
+ )
88
+
89
+ manual_btn.click(
90
+ fn=open_manual,
91
+ inputs=[],
92
+ outputs=[label_en, label_th, confidence, label_hint]
93
  )
94
 
95
  gr.Markdown("---")
96
+ gr.Markdown("App version: 1.0.1 | Updated: August 2025")
97
 
98
  demo.launch()