Spaces:
Build error
Build error
File size: 3,896 Bytes
bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 bf3c131 21135f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
# ✅ Import required libraries
import logging
import webbrowser

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
# 🔧 Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 📦 Load the trained model
MODEL_PATH = "resnet_mushroom_classifier.pth"


def load_classifier(path=MODEL_PATH, num_classes=2):
    """Build a ResNet-50 with a ``num_classes``-way head and load trained weights.

    Args:
        path: Checkpoint file holding a ``state_dict`` for this architecture.
        num_classes: Output size of the replacement final layer
            (2 = Edible, Poisonous for this app).

    Returns:
        The model moved to ``device`` and switched to evaluation mode.
    """
    # weights=None gives a randomly initialised network; strict load_state_dict
    # below overwrites every parameter. (`pretrained=` is deprecated in
    # torchvision >= 0.13 in favour of `weights=`.)
    net = models.resnet50(weights=None)
    net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so loading the checkpoint cannot run arbitrary pickled code.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    net.load_state_dict(state_dict)
    return net.to(device).eval()


model = load_classifier()
# 🏷️ Class names and species mapping
# Entry i names output unit i of the model's final layer (argmax index -> label).
class_names = ["Edible", "Poisonous"]

# Species hint displayed alongside a sufficiently confident prediction,
# keyed by the class name from class_names.
mushroom_species = dict(
    Edible="Possible species:\n• Amanita citrina\n• Russula delica\n• Phaeogyroporus portentosus",
    Poisonous="Possible species:\n• Amanita phalloides\n• Inocybe rimosa",
)
# 🎨 Image preprocessing
# Per-channel mean / std for Normalize (the widely used ImageNet statistics).
# NOTE(review): presumably identical to the training-time preprocessing — confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize to a fixed 224x224 (aspect ratio not preserved), convert to a tensor,
# then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
# 🧠 Classification function with validation
CONFIDENCE_THRESHOLD = 85.0 # Minimum confidence considered safe enough to show suggestion
def classify_mushroom(image: Image.Image):
try:
image = image.convert("RGB")
tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(tensor)
_, predicted = torch.max(outputs, 1)
label = class_names[predicted.item()]
score = torch.softmax(outputs, dim=1)[0][predicted.item()].item() * 100
suggestion = mushroom_species[label] if score >= CONFIDENCE_THRESHOLD else "Confidence too low to suggest species."
# Check confidence and warn
if score < CONFIDENCE_THRESHOLD:
label = "Uncertain"
return label, ("กินได้" if label == "Edible" else ("พิษ" if label == "Poisonous" else "ไม่มั่นใจ")), f"{score:.2f}%", suggestion
except Exception as e:
print(f"❌ Error: {e}")
return "Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo."
# 🔗 Open user manual
MANUAL_URL = "https://drive.google.com/drive/folders/19lUCEaLstrRjCzqpDlWErhRd1EXWUGbf?usp=sharing"


def open_manual():
    """Show the user-manual link in the species-hint textbox.

    This handler runs on the server. Launching a browser there (as
    ``webbrowser.open`` would) never reaches the person using the web UI and,
    on a headless host, may start a console browser on the server instead.
    Returning the URL lets the visitor open it in their own browser.

    Returns:
        Four strings for (English label, Thai label, confidence, species hint):
        the first three are cleared and the last contains the manual URL.
    """
    return "", "", "", f"User manual: {MANUAL_URL}"
# 🎛️ Gradio UI
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("## 🍄 Mushroom Safety Classifier")
        # Markdown folds a single newline into a space; a blank line keeps the
        # English and Thai instructions on separate lines.
        gr.Markdown(
            "Upload a mushroom photo or use your camera to check if it’s edible or poisonous."
            "\n\nอัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ"
        )
        with gr.Row():
            # Gradio 4 removed gr.Image's `source=` and `tool=` arguments (the
            # cause of the build error). `sources=` also enables the camera
            # capture promised in the instructions above.
            image_input = gr.Image(
                type="pil",
                label="📷 Upload or Capture Mushroom Image",
                sources=["upload", "webcam"],
            )
            with gr.Column():
                label_en = gr.Textbox(label="🧬 Prediction (English)")
                label_th = gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)")
                confidence = gr.Textbox(label="📶 Confidence Score")
                label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")
        classify_btn = gr.Button("🔍 Classify")
        manual_btn = gr.Button("Open User Manual 📄")
        # Both handlers fill the same four output boxes.
        classify_btn.click(
            fn=classify_mushroom,
            inputs=image_input,
            outputs=[label_en, label_th, confidence, label_hint],
        )
        manual_btn.click(
            fn=open_manual,
            inputs=[],
            outputs=[label_en, label_th, confidence, label_hint],
        )
        gr.Markdown("---")
        # A plain link opens in the visitor's browser, unlike any server-side open.
        gr.Markdown(f"📄 [User manual]({MANUAL_URL})")
        gr.Markdown("App version: 1.0.1 | Updated: August 2025")
    demo.launch()
|