# app.py by trapezius60 — commit 273f35a ("Update app.py", verified)
# File size: 4.17 kB
# ✅ Import required libraries
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision import models
import gradio as gr
from rembg import remove # Background removal
from transformers import pipeline # For non-mushroom detection
# 🔧 Set device: use the GPU when one is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 📦 Load your fine-tuned model
# Path of the fine-tuned ResNet-50 state dict (saved with model.state_dict()).
MODEL_WEIGHTS_PATH = "resnet_mushroom_classifier.pth"

# `weights=None` replaces the deprecated `pretrained=False`: build the bare
# ResNet-50 architecture; the actual weights come from the checkpoint below.
model = models.resnet50(weights=None)
# Replace the 1000-way ImageNet head with a 2-way head (Edible, Poisonous).
model.fc = torch.nn.Linear(model.fc.in_features, 2)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code; a state dict loads fine.
model.load_state_dict(
    torch.load(MODEL_WEIGHTS_PATH, map_location=device, weights_only=True)
)
model = model.to(device)
# Inference mode: BatchNorm layers use their running statistics.
model.eval()
# 🏷️ Class names and species mapping
# Index order must match the model's output units: classify_mushroom indexes
# this list with the argmax of the logits (0 -> "Edible", 1 -> "Poisonous").
class_names = ["Edible", "Poisonous"]

# Species hint shown to the user for each class, rendered as a bulleted list.
# NOTE(review): these hints are shown in a food-safety tool — confirm each
# species against the training data and an authoritative edibility source.
mushroom_species = {
    category: "Possible species:" + "".join(f"\n• {species}" for species in examples)
    for category, examples in (
        ("Edible", ("Amanita citrina", "Russula delica", "Phaeogyroporus portentosus")),
        ("Poisonous", ("Amanita phalloides", "Inocybe rimosa")),
    )
}
# 🎨 Image preprocessing applied before the fine-tuned model:
# stretch to 224x224, convert to a CHW float tensor in [0, 1], then normalize
# each RGB channel with the ImageNet statistics torchvision's ResNets expect.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# 🔍 General-purpose pretrained classifier, used only as a gate: its predicted
# labels are checked against `mushroom_keywords` to reject non-mushroom photos
# before the fine-tuned edible/poisonous model runs.
label_detector = pipeline(task="image-classification", model="microsoft/resnet-50")
# Lowercase substrings that mark a detector label as a fungus. Assuming
# microsoft/resnet-50 uses the ImageNet-1k label set, its fungus classes
# include "bolete", "gyromitra" and "hen-of-the-woods", which none of the
# original keywords matched, so boletes were rejected as "not a mushroom".
# "coral fungus" and "stinkhorn, carrion fungus" are already covered by "fungus".
mushroom_keywords = [
    "mushroom",
    "agaric",
    "amanita",
    "fungus",
    "earthstar",
    "toadstool",
    "bolete",
    "gyromitra",
    "hen-of-the-woods",
]
def is_mushroom_image(image, *, top_k=1, detector=None, keywords=None):
    """Return True if the screening classifier labels *image* as a fungus.

    Args:
        image: Image handed straight to the detector (the UI supplies a PIL image).
        top_k: Number of leading detector predictions to inspect. The default
            of 1 checks only the first prediction (the original behaviour);
            larger values also accept photos whose fungus label is not ranked
            first.
        detector: Callable returning a list of ``{"label": str, ...}`` dicts,
            presumably ordered best-first. Defaults to the module-level
            ``label_detector`` pipeline.
        keywords: Substrings that mark a label as a fungus. Labels are
            lowercased before matching, so keywords should be lowercase.
            Defaults to the module-level ``mushroom_keywords``.

    Returns:
        True if any inspected label contains any keyword. False otherwise,
        including when the detector raises or returns no predictions.

    Raises:
        ValueError: if ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    detector = label_detector if detector is None else detector
    keywords = mushroom_keywords if keywords is None else keywords
    try:
        labels = [prediction["label"].lower() for prediction in detector(image)[:top_k]]
    except Exception as e:
        # Best-effort gate: any detector failure (e.g. an unreadable image)
        # counts as "not a mushroom", but is logged instead of silently dropped.
        print(f"❌ Mushroom screening failed: {e}")
        return False
    return any(keyword in label for label in labels for keyword in keywords)
# 🧠 Classification function with validation
CONFIDENCE_THRESHOLD = 85.0  # Minimum confidence (percent) considered safe enough to show a label/suggestion

# Thai text shown next to each English label returned by classify_mushroom.
_THAI_LABELS = {"Edible": "กินได้", "Poisonous": "พิษ", "Uncertain": "ไม่มั่นใจ"}


def classify_mushroom(image: Image.Image):
    """Classify a photo as Edible, Poisonous or Uncertain.

    The photo is first screened with ``is_mushroom_image``. The background is
    then removed with rembg, and the result is run through the fine-tuned
    ResNet-50. A prediction whose softmax probability is below
    ``CONFIDENCE_THRESHOLD`` percent is reported as "Uncertain" with no
    species hint.

    Args:
        image: The uploaded PIL image, or None when no image was provided.

    Returns:
        A 4-tuple of strings ``(label_en, label_th, confidence, species_hint)``
        feeding the four output textboxes of the UI. Errors are reported in
        that same shape instead of being raised.
    """
    if image is None:
        # An empty image input arrives as None; report it as an invalid image
        # rather than "Not a mushroom" with a fake 0.00% confidence.
        return "Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo."
    try:
        if not is_mushroom_image(image):
            return "Not a mushroom", "ไม่ใช่เห็ด", "0.00%", "Please upload a mushroom photo."
        # rembg returns an image with transparency; drop alpha again because
        # the preprocessing pipeline expects 3-channel RGB.
        image_no_bg = remove(image.convert("RGBA")).convert("RGB")
        tensor = transform(image_no_bg).unsqueeze(0).to(device)
        with torch.no_grad():
            probabilities = torch.softmax(model(tensor), dim=1)[0]
        # Softmax is monotonic, so the argmax matches the argmax of the logits.
        top_prob, top_idx = probabilities.max(dim=0)
        label = class_names[top_idx.item()]
        score = top_prob.item() * 100
        if score >= CONFIDENCE_THRESHOLD:
            suggestion = mushroom_species[label]
        else:
            label = "Uncertain"
            suggestion = "Confidence too low to suggest species."
        return label, _THAI_LABELS[label], f"{score:.2f}%", suggestion
    except Exception as e:
        # UI boundary: log and show a friendly message instead of crashing the request.
        print(f"❌ Error: {e}")
        return "Error", "ผิดพลาด", "N/A", "Invalid image. Please upload a valid mushroom photo."
# 🎛️ Gradio UI
def _build_demo():
    """Assemble the Gradio Blocks app around ``classify_mushroom``.

    Layout: title and bilingual instructions, then an image input beside a
    column of four result textboxes, the Classify button, and a version footer.
    """
    with gr.Blocks() as app:
        gr.Markdown("## 🍄 Mushroom Safety Classifier")
        gr.Markdown("Upload a mushroom photo or use your camera to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดหรือใช้กล้องเพื่อตรวจสอบว่าเห็ดกินได้หรือมีพิษ")
        with gr.Row():
            photo = gr.Image(type="pil", label="📷 Upload or Capture Mushroom Image")
            with gr.Column():
                # Order must match classify_mushroom's return tuple.
                result_boxes = [
                    gr.Textbox(label="🧬 Prediction (English)"),
                    gr.Textbox(label="🔁 คำทำนาย (ภาษาไทย)"),
                    gr.Textbox(label="📶 Confidence Score"),
                    gr.Textbox(label="🏷️ Likely Species (Based on Training Data)"),
                ]
        gr.Button("🔍 Classify").click(classify_mushroom, inputs=photo, outputs=result_boxes)
        gr.Markdown("---")
        gr.Markdown("App version: 1.1.0 | Updated: August 2025")
    return app


if __name__ == "__main__":
    _build_demo().launch()