"""Gradio app that classifies a mushroom photo as edible or poisonous.

A fine-tuned ResNet-50 with a two-class head is loaded from
``resnet_mushroom_classifier.pth``. Each photo has its background removed
with rembg before it is classified.
"""

# ✅ Import required libraries
import logging

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from rembg import remove  # Background removal
from torchvision import models

logger = logging.getLogger(__name__)

# 🔧 Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 📦 Load your fine-tuned model
model = models.resnet50(weights=None)  # No pretrained weights; the checkpoint below supplies them
# The head must match the one used during fine-tuning, otherwise
# load_state_dict() fails on the "fc.*" keys.
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.5),
    torch.nn.Linear(model.fc.in_features, 2),
)
# The checkpoint is a plain state dict, so refuse arbitrary pickled objects.
model.load_state_dict(
    torch.load("resnet_mushroom_classifier.pth", map_location=device, weights_only=True)
)
model = model.to(device)
model.eval()

# 🏷️ Class names and species mapping
# Order must match the model's output indices (class_names[i] labels logit i).
class_names = ['Edible', 'Poisonous']
# NOTE(review): Amanita citrina is commonly described as inedible and resembles
# A. phalloides; confirm this "Edible" listing with a mycologist before release.
mushroom_species = {
    "Edible": "Possible species:\n• Amanita citrina\n• Russula delica\n• Phaeogyroporus portentosus",
    "Poisonous": "Possible species:\n• Amanita phalloides\n• Inocybe rimosa",
}
# Thai translation of the English label; any other label (e.g. "Uncertain")
# falls back to "ไม่มั่นใจ" ("not sure").
_THAI_LABELS = {"Edible": "กินได้", "Poisonous": "พิษ"}
# Returned for a missing or unprocessable image:
# (English label, Thai label, confidence, message).
_ERROR_RESULT = (
    "Error",
    "ผิดพลาด",
    "N/A",
    "Invalid image. Please upload a valid mushroom photo.",
)

# 🎨 Image preprocessing (ImageNet normalisation statistics)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# 🧠 Classification function with confidence filtering
CONFIDENCE_THRESHOLD = 85.0  # Minimum confidence (percent) considered safe enough to show suggestion


def classify_mushroom(image: Image.Image):
    """Classify a mushroom photo as edible, poisonous, or uncertain.

    Args:
        image: PIL image from the Gradio input, or ``None`` when no image
            was provided.

    Returns:
        A 4-tuple of strings:

        - English label: "Edible", "Poisonous", or "Uncertain" when the
          top-class confidence is below ``CONFIDENCE_THRESHOLD``.
        - Thai label.
        - Confidence as a percentage, e.g. "91.23%".
        - Species hint.

        If the image is missing or processing fails, ``_ERROR_RESULT`` is
        returned instead.
    """
    if image is None:
        # Nothing to classify (an empty Gradio image input yields None).
        return _ERROR_RESULT

    try:
        # rembg is fed RGBA; the model expects 3-channel RGB.
        image_no_bg = remove(image.convert("RGBA")).convert("RGB")
        tensor = transform(image_no_bg).unsqueeze(0).to(device)  # add batch dimension
        with torch.no_grad():
            outputs = model(tensor)
        probabilities = torch.softmax(outputs, dim=1)[0]
        confidence, predicted = torch.max(probabilities, dim=0)
        score = confidence.item() * 100
        label = class_names[predicted.item()]
    except Exception:
        # UI boundary: log the full traceback, but show a friendly message
        # instead of crashing the request.
        logger.exception("❌ Error while classifying mushroom image")
        return _ERROR_RESULT

    if score >= CONFIDENCE_THRESHOLD:
        suggestion = mushroom_species[label]
    else:
        label = "Uncertain"
        suggestion = "Confidence too low to suggest species."
    return label, _THAI_LABELS.get(label, "ไม่มั่นใจ"), f"{score:.2f}%", suggestion


# 🔗 Open user manual (link version) 🆕
MANUAL_URL = "https://drive.google.com/drive/folders/19lUCEaLstrRjCzqpDlWErhRd1EXWUGbf?usp=sharing"
manual_link = f"[📄 Open User Manual]({MANUAL_URL})"
FEEDBACK_URL = "https://forms.gle/doRsn8kk3y6po5d77"
feedback_link = f"[❤️ If you need feedback]({FEEDBACK_URL})"

# 🎛️ Gradio UI
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("## 🍄 Mushroom Safety Classifier")
        gr.Markdown("Upload a mushroom photo or use your camera to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ")

        with gr.Row():
            image_input = gr.Image(type="pil", label="📷 Upload or Capture Mushroom Image")
            with gr.Column():
                label_en = gr.Textbox(label="🧬 Prediction (English)")
                label_th = gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)")
                confidence = gr.Textbox(label="📶 Confidence Score")
                label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")

        classify_btn = gr.Button("🔍 Classify")
        classify_btn.click(
            fn=classify_mushroom,
            inputs=image_input,
            outputs=[label_en, label_th, confidence, label_hint],
        )

        gr.Markdown("---")
        gr.Markdown(manual_link)  # 🆕 Display clickable user manual
        gr.Markdown(feedback_link)  # 🆕 Display clickable feedback form
        gr.Markdown("App version: 1.1.1 | Updated: August 2025")

    demo.launch()