File size: 4,068 Bytes
bf3c131
21135f4
 
 
273f35a
 
285ba5e
21135f4
 
 
 
285ba5e
ae7b3ca
5291024
 
 
 
3e6e46c
21135f4
 
 
 
bf3c131
21135f4
 
 
 
 
 
bf3c131
d2d1936
 
21135f4
 
 
 
 
40f5fb8
bf3c131
 
21135f4
 
285ba5e
 
e3a1380
21135f4
 
 
 
 
 
 
bf3c131
21135f4
bf3c131
 
 
d2d1936
21135f4
 
 
bf3c131
21135f4
5c573bd
 
3708e8e
 
b54e5c4
5c573bd
21135f4
 
 
 
bf3c131
21135f4
 
285ba5e
c295347
bf3c131
 
21135f4
 
 
 
 
 
 
 
bf3c131
 
 
21135f4
5c573bd
6b704c2
40f5fb8
21135f4
 
285ba5e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# ✅ Import required libraries
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision import models
import gradio as gr
from rembg import remove  # Background removal

# 🔧 Compute device: CUDA when PyTorch reports it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 📦 Rebuild the fine-tuned network, then load its trained weights.
# The ResNet-50 backbone is created without pretrained weights because every
# parameter is overwritten by the checkpoint loaded below.
model = models.resnet50(weights=None)

# Swap the stock classifier for dropout + a 2-way linear head (one logit per
# entry in class_names). This layer structure must match the checkpoint, since
# load_state_dict is strict by default and rejects missing/unexpected keys.
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.5),
    torch.nn.Linear(model.fc.in_features, 2),
)

# NOTE(review): torch.load unpickles the file; that is acceptable for a bundled,
# trusted checkpoint. On PyTorch versions that support it, consider passing
# weights_only=True so only tensors/plain containers can be deserialized.
model.load_state_dict(
    torch.load("resnet_mushroom_classifier.pth", map_location=device)
)

# Move to the target device and switch to inference mode (Dropout disabled,
# BatchNorm uses its running statistics). .to() and .eval() both return the
# module itself, so chaining keeps the same object.
model = model.to(device).eval()

# 🏷️ Prediction labels, ordered to match the model's output logits
# (index 0 -> "Edible", index 1 -> "Poisonous").
class_names = ["Edible", "Poisonous"]

# Example species per class, shown to the user as a hint when the prediction
# is confident enough.
# NOTE(review): confirm these lists against the training-data labels; some
# field guides advise against eating Amanita citrina because it resembles
# deadly Amanita species.
_SPECIES_BY_CLASS = {
    "Edible": ("Amanita citrina", "Russula delica", "Phaeogyroporus portentosus"),
    "Poisonous": ("Amanita phalloides", "Inocybe rimosa"),
}

# Render each class's species as a bulleted, newline-separated hint string.
mushroom_species = {
    label: "Possible species:\n" + "\n".join(f"• {species}" for species in names)
    for label, names in _SPECIES_BY_CLASS.items()
}

# 🎨 Image preprocessing: resize to a fixed 224x224, convert to a float tensor,
# then normalize each RGB channel with the standard ImageNet mean/std.
# NOTE(review): this must mirror the preprocessing used at training time — confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# 🧠 Confidence filtering for the classifier below.
# Predictions whose top-class probability (as a percentage) falls below this
# value are reported as "Uncertain" and get no species suggestion.
CONFIDENCE_THRESHOLD = 85.0

# Thai translations of every English label classify_mushroom can emit on success.
_THAI_LABELS = {
    "Edible": "กินได้",
    "Poisonous": "พิษ",
    "Uncertain": "ไม่มั่นใจ",
}

# Returned when no usable image is provided or classification fails.
_ERROR_RESULT = ("Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo.")


def classify_mushroom(image: Image.Image):
    """Classify a mushroom photo as edible or poisonous.

    The background is removed with rembg, the result is preprocessed with
    ``transform`` and passed through the fine-tuned ResNet-50.

    Args:
        image: The uploaded PIL image, or ``None`` when the user clicks
            "Classify" without providing one (Gradio passes ``None`` for an
            empty ``gr.Image``).

    Returns:
        A 4-tuple of strings ``(label_en, label_th, confidence, suggestion)``:
        ``label_en`` is "Edible", "Poisonous", "Uncertain" (top-class
        confidence below CONFIDENCE_THRESHOLD) or "Error"; ``label_th`` is its
        Thai translation; ``confidence`` is the top-class softmax probability
        formatted as a percentage (``"N/A"`` on error); ``suggestion`` is the
        species hint or an explanatory message.
    """
    # Nothing uploaded: answer directly instead of provoking (and logging)
    # an AttributeError inside the try block.
    if image is None:
        return _ERROR_RESULT

    try:
        # Strip the background, then flatten back to 3-channel RGB for the network.
        foreground = remove(image.convert("RGBA")).convert("RGB")
        batch = transform(foreground).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(batch)
            _, predicted = torch.max(logits, 1)
            class_idx = predicted.item()
            score = torch.softmax(logits, dim=1)[0, class_idx].item() * 100

        if score >= CONFIDENCE_THRESHOLD:
            label = class_names[class_idx]
            suggestion = mushroom_species[label]
        else:
            # Too close to call: withhold both the class and the species hint.
            label = "Uncertain"
            suggestion = "Confidence too low to suggest species."

        return label, _THAI_LABELS[label], f"{score:.2f}%", suggestion

    except Exception as e:
        # UI boundary: log the failure and show a friendly message rather than
        # letting the exception escape the Gradio request handler.
        print(f"❌ Error: {e}")
        return _ERROR_RESULT

# 🔗 External resources linked from the app footer, rendered as HTML anchors
# that open in a new browser tab.
MANUAL_URL = "https://drive.google.com/drive/folders/19lUCEaLstrRjCzqpDlWErhRd1EXWUGbf?usp=sharing"
FEEDBACK_URL = "https://forms.gle/doRsn8kk3y6po5d77"


def _new_tab_link(url, text):
    """Return an HTML anchor pointing at *url*, showing *text*, opening in a new tab."""
    return '<a href="{}" target="_blank">{}</a>'.format(url, text)


manual_link = _new_tab_link(MANUAL_URL, "📄 Open User Manual")
# NOTE(review): the label reads "If you need feedback", but the target is a form
# for *giving* feedback — consider rewording (e.g. "Send feedback").
feedback_link = _new_tab_link(FEEDBACK_URL, "❤️ If you need feedback")

# 🎛️ Gradio UI
# The interface is built and launched only when this file is run directly.
# NOTE: component creation order inside the Blocks/Row/Column contexts below
# defines the page layout, so statements here should not be reordered.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        # Header: title plus bilingual (English/Thai) instructions.
        gr.Markdown("## 🍄 Mushroom Safety Classifier")
        gr.Markdown("Upload a mushroom photo or use your camera to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ")

        with gr.Row():
            # Left side: the input image; type="pil" makes Gradio pass a PIL image to classify_mushroom.
            image_input = gr.Image(type="pil", label="📷 Upload or Capture Mushroom Image")
            # Right side: one textbox per element of the 4-tuple returned by classify_mushroom.
            with gr.Column():
                label_en = gr.Textbox(label="🧬 Prediction (English)")
                label_th = gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)")
                confidence = gr.Textbox(label="📶 Confidence Score")
                label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")

        classify_btn = gr.Button("🔍 Classify")

        # Wire the button: the outputs list is filled positionally from
        # classify_mushroom's return tuple (label_en, label_th, confidence, hint).
        classify_btn.click(
            fn=classify_mushroom,
            inputs=image_input,
            outputs=[label_en, label_th, confidence, label_hint]
        )

        # Footer: separator, links to the user manual and feedback form, version stamp.
        gr.Markdown("---")
        gr.Markdown(manual_link)  # 🆕 Display clickable user manual
        gr.Markdown(feedback_link)  # 🆕 Display clickable feedback form
        gr.Markdown("App version: 1.1.1 | Updated: August 2025")

    # Start the Gradio server for the interface defined above.
    demo.launch()