Spaces:
Build error
Build error
Update app.py
Browse filesversion 7/8/2025
# ✅ Training Script Summary:
# - Use torchvision.datasets.ImageFolder with folders "Edible/" and "Poisonous/"
# - Apply data augmentation (flip, jitter, rotate, etc.)
# - Split into train/val and load with DataLoader
# - Use pretrained ResNet50 with replaced final FC layer
# - Train using CrossEntropyLoss and Adam optimizer
# - Save model using torch.save(...)
app.py
CHANGED
@@ -2,15 +2,14 @@
|
|
2 |
from PIL import Image
|
3 |
import torch
|
4 |
import torchvision.transforms as transforms
|
5 |
-
from torchvision import models
|
6 |
-
import
|
7 |
-
import
|
8 |
-
from rembg import remove # 🆕 Background removal
|
9 |
|
10 |
# 🔧 Set device
|
11 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
|
13 |
-
# 📦 Load
|
14 |
model = models.resnet50(pretrained=False)
|
15 |
model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
|
16 |
model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
|
@@ -32,13 +31,28 @@ transform = transforms.Compose([
|
|
32 |
[0.229, 0.224, 0.225])
|
33 |
])
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
# 🧠 Classification function with validation
|
36 |
CONFIDENCE_THRESHOLD = 85.0 # Minimum confidence considered safe enough to show suggestion
|
37 |
|
38 |
def classify_mushroom(image: Image.Image):
|
39 |
try:
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
42 |
tensor = transform(image_no_bg).unsqueeze(0).to(device)
|
43 |
|
44 |
with torch.no_grad():
|
@@ -49,7 +63,6 @@ def classify_mushroom(image: Image.Image):
|
|
49 |
|
50 |
suggestion = mushroom_species[label] if score >= CONFIDENCE_THRESHOLD else "Confidence too low to suggest species."
|
51 |
|
52 |
-
# Check confidence and warn
|
53 |
if score < CONFIDENCE_THRESHOLD:
|
54 |
label = "Uncertain"
|
55 |
|
@@ -59,9 +72,6 @@ def classify_mushroom(image: Image.Image):
|
|
59 |
print(f"❌ Error: {e}")
|
60 |
return "Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo."
|
61 |
|
62 |
-
# 🔗 Open user manual (link version) 🆕
|
63 |
-
MANUAL_URL = "https://drive.google.com/drive/folders/19lUCEaLstrRjCzqpDlWErhRd1EXWUGbf?usp=sharing"
|
64 |
-
manual_link = f"<a href=\"{MANUAL_URL}\" target=\"_blank\">📄 Open User Manual</a>"
|
65 |
# 🎛️ Gradio UI
|
66 |
if __name__ == "__main__":
|
67 |
with gr.Blocks() as demo:
|
@@ -69,7 +79,7 @@ if __name__ == "__main__":
|
|
69 |
gr.Markdown("Upload a mushroom photo or use your camera to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ")
|
70 |
|
71 |
with gr.Row():
|
72 |
-
image_input = gr.Image(type="pil", label="📷 Upload or Capture Mushroom Image")
|
73 |
with gr.Column():
|
74 |
label_en = gr.Textbox(label="🧬 Prediction (English)")
|
75 |
label_th = gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)")
|
@@ -85,7 +95,8 @@ if __name__ == "__main__":
|
|
85 |
)
|
86 |
|
87 |
gr.Markdown("---")
|
88 |
-
gr.Markdown(
|
89 |
-
gr.Markdown("App version: 1.0.2 | Updated: August 2025") # 🆕 Version bump
|
90 |
|
91 |
demo.launch()
|
|
|
|
|
|
2 |
from PIL import Image
|
3 |
import torch
|
4 |
import torchvision.transforms as transforms
|
5 |
+
from torchvision import models\import gradio as gr
|
6 |
+
from rembg import remove # Background removal
|
7 |
+
from transformers import pipeline # For non-mushroom detection
|
|
|
8 |
|
9 |
# 🔧 Set device
|
10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
|
12 |
+
# 📦 Load your fine-tuned model
|
13 |
model = models.resnet50(pretrained=False)
|
14 |
model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
|
15 |
model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
|
|
|
31 |
[0.229, 0.224, 0.225])
|
32 |
])
|
33 |
|
34 |
+
# 🔍 Pretrained image classifier for screening non-mushroom
|
35 |
+
label_detector = pipeline("image-classification", model="microsoft/resnet-50")
|
36 |
+
mushroom_keywords = ["mushroom", "agaric", "amanita", "fungus", "earthstar", "toadstool"]
|
37 |
+
|
38 |
+
def is_mushroom_image(image):
|
39 |
+
try:
|
40 |
+
result = label_detector(image)
|
41 |
+
top_label = result[0]["label"].lower()
|
42 |
+
return any(keyword in top_label for keyword in mushroom_keywords)
|
43 |
+
except:
|
44 |
+
return False
|
45 |
+
|
46 |
# 🧠 Classification function with validation
|
47 |
CONFIDENCE_THRESHOLD = 85.0 # Minimum confidence considered safe enough to show suggestion
|
48 |
|
49 |
def classify_mushroom(image: Image.Image):
|
50 |
try:
|
51 |
+
if not is_mushroom_image(image):
|
52 |
+
return "Not a mushroom", "ไม่ใช่เห็ด", "0.00%", "Please upload a mushroom photo."
|
53 |
+
|
54 |
+
image = image.convert("RGBA")
|
55 |
+
image_no_bg = remove(image).convert("RGB")
|
56 |
tensor = transform(image_no_bg).unsqueeze(0).to(device)
|
57 |
|
58 |
with torch.no_grad():
|
|
|
63 |
|
64 |
suggestion = mushroom_species[label] if score >= CONFIDENCE_THRESHOLD else "Confidence too low to suggest species."
|
65 |
|
|
|
66 |
if score < CONFIDENCE_THRESHOLD:
|
67 |
label = "Uncertain"
|
68 |
|
|
|
72 |
print(f"❌ Error: {e}")
|
73 |
return "Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo."
|
74 |
|
|
|
|
|
|
|
75 |
# 🎛️ Gradio UI
|
76 |
if __name__ == "__main__":
|
77 |
with gr.Blocks() as demo:
|
|
|
79 |
gr.Markdown("Upload a mushroom photo or use your camera to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ")
|
80 |
|
81 |
with gr.Row():
|
82 |
+
image_input = gr.Image(type="pil", label="📷 Upload or Capture Mushroom Image")
|
83 |
with gr.Column():
|
84 |
label_en = gr.Textbox(label="🧬 Prediction (English)")
|
85 |
label_th = gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)")
|
|
|
95 |
)
|
96 |
|
97 |
gr.Markdown("---")
|
98 |
+
gr.Markdown("App version: 1.1.0 | Updated: August 2025")
|
|
|
99 |
|
100 |
demo.launch()
|
101 |
+
|
102 |
+
|