File size: 4,371 Bytes
bf3c131
21135f4
 
 
 
 
bf3c131
e3a1380
21135f4
 
 
 
bf3c131
21135f4
 
 
 
 
 
bf3c131
21135f4
 
 
 
 
 
bf3c131
4f32aae
 
 
 
 
 
 
 
 
21135f4
 
 
 
 
bf3c131
 
 
21135f4
 
e3a1380
 
 
21135f4
 
 
 
 
 
 
bf3c131
21135f4
bf3c131
 
 
 
 
21135f4
 
 
bf3c131
21135f4
4f32aae
bf3c131
4f32aae
21135f4
 
 
 
 
bf3c131
21135f4
 
4f32aae
21135f4
bf3c131
 
21135f4
 
 
 
 
 
 
 
bf3c131
 
 
21135f4
4f32aae
 
21135f4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# ✅ Import required libraries
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision import models
import gradio as gr
import webbrowser
from rembg import remove  # 🆕 Background removal

# 🔧 Select the compute device: CUDA when torch reports it as available, else CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 📦 Rebuild the network architecture, then restore the trained weights
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=None` — confirm the pinned torchvision version before changing it.
model = models.resnet50(pretrained=False)
# Swap the final layer for a 2-output head (2 classes: Edible, Poisonous)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)
# map_location lets the checkpoint load onto whichever device was selected above
model.load_state_dict(
    torch.load("resnet_mushroom_classifier.pth", map_location=device)
)
# Move to the device and switch to inference mode (both calls return the module)
model = model.to(device).eval()

# 🏷️ Class labels and the species hint text shown for each label
# NOTE(review): field guides commonly describe Amanita citrina as inedible —
# confirm the "Edible" grouping below against the training-data labels.
class_names = ['Edible', 'Poisonous']
mushroom_species = {
    label: "\n• ".join(["Possible species:", *species])
    for label, species in (
        ("Edible", ("Amanita citrina", "Russula delica", "Phaeogyroporus portentosus")),
        ("Poisonous", ("Amanita phalloides", "Inocybe rimosa")),
    )
}

# 🎨 Image preprocessing
# Per-channel normalisation statistics shared by both pipelines below.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training-time pipeline: random crop / flips / rotation / colour jitter / affine.
# It is random on every call, so it must NOT be used for prediction; it is kept
# here so any existing reference to `train_transform` still resolves.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=45),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Deterministic inference pipeline. classify_mushroom() calls `transform`, which
# was never defined: every request raised NameError and the UI always showed
# "Error". Same resize, final 224x224 size and normalisation as training, with
# the random steps replaced by a fixed centre crop.
# NOTE(review): the validation transform used when the checkpoint was trained is
# not visible in this file — confirm Resize(256) + CenterCrop(224) matches it.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# 🧠 Classification function with validation
CONFIDENCE_THRESHOLD = 85.0  # Percent; below this the prediction is reported as "Uncertain"

def classify_mushroom(image: Image.Image):
    """Classify a mushroom photo and build the four values shown in the UI.

    The background is removed with rembg, the result is preprocessed with the
    module-level `transform`, and the model's top class is read together with
    its softmax probability expressed as a percentage.

    Returns a 4-tuple of strings:
        (English label, Thai label, confidence text, species hint).
    When the confidence is below CONFIDENCE_THRESHOLD the English label becomes
    "Uncertain" and no species are suggested. Any exception raised along the
    way is printed and converted into a fixed "Error" tuple instead of being
    propagated to the caller.
    """
    thai_labels = {"Edible": "กินได้", "Poisonous": "พิษ"}

    try:
        rgba_image = image.convert("RGBA")  # Required format for rembg
        foreground = remove(rgba_image).convert("RGB")  # 🆕 Remove background
        batch = transform(foreground).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(batch)
            _, predicted = torch.max(logits, 1)
            class_index = predicted.item()
            label = class_names[class_index]
            probabilities = torch.softmax(logits, dim=1)
            score = probabilities[0][class_index].item() * 100

        # Species are only suggested when the model is confident enough
        if score >= CONFIDENCE_THRESHOLD:
            suggestion = mushroom_species[label]
        else:
            suggestion = "Confidence too low to suggest species."

        # Low-confidence predictions are relabelled rather than shown as-is
        if score < CONFIDENCE_THRESHOLD:
            label = "Uncertain"

        # Anything that is not a known class label maps to the "not sure" text
        label_thai = thai_labels.get(label, "ไม่มั่นใจ")
        return label, label_thai, f"{score:.2f}%", suggestion

    except Exception as e:
        print(f"❌ Error: {e}")
        return "Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo."

# 🔗 User manual: shared-folder URL and the HTML anchor that opens it in a new tab 🆕
MANUAL_URL = "https://drive.google.com/drive/folders/19lUCEaLstrRjCzqpDlWErhRd1EXWUGbf?usp=sharing"
manual_link = '<a href="{}" target="_blank">📄 Open User Manual</a>'.format(MANUAL_URL)

# 🎛️ Gradio UI
# Built and launched only when this file is executed as a script, so importing
# the module does not start a web server.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        # Page title and bilingual (English / Thai) instructions
        gr.Markdown("## 🍄 Mushroom Safety Classifier")
        gr.Markdown("Upload a mushroom photo or use your camera to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ")

        # One row: the image input, then a column of four read-only result boxes
        with gr.Row():
            # type="pil" matches classify_mushroom's PIL.Image parameter annotation.
            # NOTE(review): source="upload" while the label and intro text mention
            # capturing with a camera — confirm whether webcam input is intended.
            # NOTE(review): the `source` keyword appears to be Gradio 3.x API (newer
            # releases use `sources=[...]`) — confirm against the pinned Gradio version.
            image_input = gr.Image(type="pil", label="📷 Upload or Capture Mushroom Image", source="upload")  # 🆕 safer source usage
            with gr.Column():
                label_en = gr.Textbox(label="🧬 Prediction (English)")
                label_th = gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)")
                confidence = gr.Textbox(label="📶 Confidence Score")
                label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")

        classify_btn = gr.Button("🔍 Classify")

        # The four outputs are listed in the same order as the 4-tuple returned by
        # classify_mushroom: English label, Thai label, confidence, species hint.
        classify_btn.click(
            fn=classify_mushroom,
            inputs=image_input,
            outputs=[label_en, label_th, confidence, label_hint]
        )

        # Footer: separator, user-manual link (HTML anchor) and version string
        gr.Markdown("---")
        gr.Markdown(manual_link)  # 🆕 Display clickable user manual
        gr.Markdown("App version: 1.0.2 | Updated: August 2025")  # 🆕 Version bump

    # Start the Gradio server for the assembled Blocks app
    demo.launch()