Spaces:
Build error
Build error
File size: 4,786 Bytes
bf3c131 21135f4 273f35a 285ba5e ae7b3ca 21135f4 285ba5e ae7b3ca 21135f4 bf3c131 21135f4 bf3c131 d2d1936 21135f4 285ba5e bf3c131 21135f4 285ba5e e3a1380 21135f4 bf3c131 21135f4 bf3c131 d2d1936 21135f4 bf3c131 21135f4 5c573bd 21135f4 bf3c131 21135f4 285ba5e c295347 bf3c131 21135f4 bf3c131 21135f4 5c573bd 285ba5e 21135f4 285ba5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
# ✅ Import required libraries
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision import models
import gradio as gr
from rembg import remove # Background removal
from transformers import pipeline # For non-mushroom detection
from torchvision.models import ResNet50_Weights
# 🔧 Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 📦 Load your fine-tuned model
MODEL_PATH = "resnet_mushroom_classifier.pth"
model = models.resnet50(weights=None)  # Architecture only; fine-tuned weights are loaded below
# Replace the classification head with a 2-way head (index order matches class_names).
# This layout must match the saved checkpoint, otherwise load_state_dict rejects it.
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.5),
    torch.nn.Linear(model.fc.in_features, 2),
)
# weights_only=True restricts unpickling to tensors/plain containers, which is all a
# state_dict needs, and prevents a tampered checkpoint from executing arbitrary code.
# NOTE(review): requires torch >= 1.13.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model = model.to(device)
model.eval()  # Inference mode for Dropout/BatchNorm layers
# 🏷️ Class names and species mapping
# Position i in class_names is the label for model output index i.
class_names = ['Edible', 'Poisonous']
# Species hint text shown to the user for each predicted class.
# NOTE(review): shown as safety guidance — confirm each species' edibility against
# the training labels.
mushroom_species = {
    "Edible": (
        "Possible species:\n"
        "• Amanita citrina\n"
        "• Russula delica\n"
        "• Phaeogyroporus portentosus"
    ),
    "Poisonous": (
        "Possible species:\n"
        "• Amanita phalloides\n"
        "• Inocybe rimosa"
    ),
}
# 🎨 Image preprocessing: resize to the 224x224 model input, convert to a tensor,
# then normalize each of the three channels with the mean/std below
# (standard ImageNet statistics — presumably what the model was trained with).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# 🔍 Pretrained ImageNet classifier used to screen out non-mushroom photos
label_detector = pipeline("image-classification", model="microsoft/resnet-50")
# Substrings matched against the detector's lowercased label(s). The entries must cover
# every ImageNet fungus class: "mushroom", "agaric", "gyromitra", "stinkhorn, carrion
# fungus", "earthstar", "hen-of-the-woods, ...", "bolete" and "coral fungus". Without
# "bolete", "gyromitra" and "hen-of-the-woods", those mushrooms would be rejected.
mushroom_keywords = [
    "mushroom",
    "agaric",
    "amanita",
    "fungus",
    "earthstar",
    "toadstool",
    "bolete",
    "gyromitra",
    "hen-of-the-woods",
]
def is_mushroom_image(image, top_k=1):
    """Return True if the screening classifier thinks *image* shows a mushroom.

    Args:
        image: Image passed straight to ``label_detector``.
        top_k: How many of the detector's highest-ranked labels to inspect.
            The default of 1 checks only the top prediction.

    Returns:
        True if any inspected label contains one of ``mushroom_keywords``.
        False otherwise, including when the detector returns no labels or
        raises (the screen is best-effort and never crashes the caller).
    """
    try:
        results = label_detector(image)
        # Assumes results are ordered by descending score, as the original
        # result[0] "top label" lookup did.
        for prediction in results[:top_k]:
            label = prediction["label"].lower()
            if any(keyword in label for keyword in mushroom_keywords):
                return True
        return False
    except Exception as e:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate.
        print(f"❌ Mushroom screening failed: {e}")
        return False
# 🧠 Classification function with validation
CONFIDENCE_THRESHOLD = 85.0  # Minimum confidence (percent) required to show a species suggestion
# Thai text for each label a successful prediction can produce.
_THAI_LABELS = {"Edible": "กินได้", "Poisonous": "พิษ", "Uncertain": "ไม่มั่นใจ"}


def classify_mushroom(image: Image.Image):
    """Classify a mushroom photo as edible or poisonous.

    Processing steps:
      1. Screen the image with ``is_mushroom_image`` and reject non-mushroom photos.
      2. Remove the background with ``remove``.
      3. Preprocess the result with ``transform``.
      4. Run the fine-tuned ``model``.

    Args:
        image: PIL image from the Gradio input, or ``None`` when no image was
            provided.

    Returns:
        A 4-tuple of strings ``(label_en, label_th, confidence, suggestion)``.
        The entries are, in order:
          * ``label_en``: one of ``"Edible"``, ``"Poisonous"``, ``"Uncertain"``
            (score below ``CONFIDENCE_THRESHOLD``), ``"Not a mushroom"``,
            ``"No image"`` or ``"Error"``.
          * ``label_th``: the Thai version of ``label_en``.
          * ``confidence``: the model's confidence as a percentage string.
          * ``suggestion``: the species hint, or a message explaining why
            there is none.
    """
    if image is None:
        # Nothing uploaded: answer directly instead of running the screening model.
        return "No image", "ไม่มีรูปภาพ", "N/A", "Please upload a mushroom photo."
    try:
        if not is_mushroom_image(image):
            return "Not a mushroom", "ไม่ใช่เห็ด", "0.00%", "Please upload a mushroom photo."

        # Convert to RGBA before background removal, then back to 3-channel RGB for the model.
        image_no_bg = remove(image.convert("RGBA")).convert("RGB")
        tensor = transform(image_no_bg).unsqueeze(0).to(device)  # add a batch dimension of 1

        with torch.no_grad():
            outputs = model(tensor)
        # Softmax preserves the argmax of the logits, so class and confidence come from one pass.
        probabilities = torch.softmax(outputs, dim=1)[0]
        top_probability, predicted = torch.max(probabilities, dim=0)
        label = class_names[predicted.item()]
        score = top_probability.item() * 100

        if score >= CONFIDENCE_THRESHOLD:
            suggestion = mushroom_species[label]
        else:
            label = "Uncertain"
            suggestion = "Confidence too low to suggest species."
        return label, _THAI_LABELS.get(label, "ไม่มั่นใจ"), f"{score:.2f}%", suggestion
    except Exception as e:
        # UI boundary: report the failure in the output textboxes instead of crashing the app.
        print(f"❌ Error: {e}")
        return "Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo."
# 🔗 External links shown at the bottom of the UI 🆕
MANUAL_URL = "https://drive.google.com/drive/folders/19lUCEaLstrRjCzqpDlWErhRd1EXWUGbf?usp=sharing"
FEEDBACK_URL = "https://forms.gle/k5zE2xoUudzjqqS29"


def _external_link(url: str, text: str) -> str:
    """Return an HTML anchor that opens *url* in a new browser tab."""
    # rel="noopener noreferrer" stops the opened page from reaching back via window.opener.
    return f'<a href="{url}" target="_blank" rel="noopener noreferrer">{text}</a>'


manual_link = _external_link(MANUAL_URL, "📄 Open User Manual")
# FEEDBACK_URL is a form for users to *send* feedback, so the label invites them to do that.
feedback_link = _external_link(FEEDBACK_URL, "📝 Send Feedback")
# 🎛️ Gradio UI
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("## 🍄 Mushroom Safety Classifier")
        # Markdown renders a single "\n" as a soft break (a space). A blank line is
        # needed so the Thai instructions appear on their own line.
        gr.Markdown(
            "Upload a mushroom photo or use your camera to check if it’s edible or poisonous.\n\n"
            "อัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ"
        )
        with gr.Row():
            image_input = gr.Image(type="pil", label="📷 Upload or Capture Mushroom Image")
            with gr.Column():
                label_en = gr.Textbox(label="🧬 Prediction (English)")
                label_th = gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)")
                confidence = gr.Textbox(label="📶 Confidence Score")
                label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")
        classify_btn = gr.Button("🔍 Classify")
        # The four textboxes receive classify_mushroom's 4-tuple, in order.
        classify_btn.click(
            fn=classify_mushroom,
            inputs=image_input,
            outputs=[label_en, label_th, confidence, label_hint],
        )
        gr.Markdown("---")
        gr.Markdown(manual_link)  # 🆕 Display clickable user manual link
        gr.Markdown(feedback_link)  # 🆕 Display clickable feedback form link
        gr.Markdown("App version: 1.1.0 | Updated: August 2025")
    demo.launch()
|