trapezius60's picture
Update app.py
21135f4 verified
raw
history blame
3.45 kB
import datetime
import io
import os
import traceback

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from rembg import remove # 🆕 Background removal
from torchvision import models
# 🔧 Set device: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 📦 Load your fine-tuned model.
# Build the bare ResNet-50 architecture without downloading any weights
# (``weights=None`` replaces the deprecated ``pretrained=False``); the
# fine-tuned parameters come from the checkpoint loaded below.
model = models.resnet50(weights=None)
# Swap the default 1000-class head for a 2-class head: Edible, Poisonous.
model.fc = torch.nn.Linear(model.fc.in_features, 2)

# Resolve the checkpoint next to this script so the app also starts when it
# is launched from a different working directory.
_CHECKPOINT_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "resnet_mushroom_classifier.pth"
)
# The checkpoint is a plain state_dict (it is consumed by load_state_dict), so
# weights_only=True is sufficient and avoids unpickling arbitrary objects.
# It is loaded inline so no extra copy of the tensors stays referenced.
model.load_state_dict(
    torch.load(_CHECKPOINT_PATH, map_location=device, weights_only=True)
)
model = model.to(device)
# Inference mode: BatchNorm layers use their stored running statistics.
model.eval()
# 🏷️ Class names, indexed by the model's output position
# (index 0 -> "Edible", index 1 -> "Poisonous").
# NOTE(review): this must be the label order used during training — confirm
# against the training script.
class_names = ["Edible", "Poisonous"]

# 🍄 Species hints displayed for each predicted class. The raw species names
# live in a private table; the public ``mushroom_species`` mapping holds the
# formatted multi-line text shown in the UI.
# NOTE(review): these lists are safety-critical — have them checked by someone
# who can reliably identify these species.
_SPECIES_BY_CLASS = {
    "Edible": ("Amanita citrina", "Russula delica", "Phaeogyroporus portentosus"),
    "Poisonous": ("Amanita phalloides", "Inocybe rimosa"),
}
mushroom_species = {
    label: "\n".join(["Possible species:", *(f"• {name}" for name in names)])
    for label, names in _SPECIES_BY_CLASS.items()
}
# 🎨 Image preprocessing — must match what the model saw during training.
# Per-channel ImageNet mean/std, the normalisation torchvision's ResNets use.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        # Fixed 224x224 input; the aspect ratio is not preserved.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor (C, H, W) scaled to [0, 1].
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# 🧠 Prediction function
def classify_mushroom(image: Image.Image):
    """Classify a mushroom photo as edible or poisonous.

    The background is removed with ``rembg`` before the photo is passed to the
    ResNet-50 classifier. The cut-out is returned too, so the UI can show what
    the model actually looked at.

    Args:
        image: Photo from the Gradio ``Image`` input (``type="pil"``); ``None``
            when Classify is clicked before anything is uploaded.

    Returns:
        A 5-tuple ``(label_en, label_th, confidence, species_hint, preview)``:
        the English class name, its Thai translation, the softmax probability
        of the predicted class formatted as ``"NN.NN%"``, the species hint text
        and the background-removed image. On failure the first four entries
        are ``"Error"``, ``"ผิดพลาด"``, ``"N/A"``, ``"N/A"`` and the last one is
        the (possibly RGB-converted) input image.
    """
    if image is None:
        # Nothing uploaded yet: report the same placeholders as any failure.
        return "Error", "ผิดพลาด", "N/A", "N/A", None

    try:
        image = image.convert("RGB")
        # 🆕 Remove background. rembg returns an RGBA cut-out; it is kept as-is
        # for display because the alpha channel shows what was removed.
        bg_removed = remove(image)
        # The classifier takes 3-channel input: feeding the RGBA cut-out
        # directly gives a 4-channel tensor that Normalize (3 values) and the
        # ResNet stem reject, so every request used to end in "Error".
        # NOTE(review): confirm the training images were preprocessed the same
        # way (background removed, alpha dropped).
        model_input = bg_removed.convert("RGB")
        tensor = transform(model_input).unsqueeze(0).to(device)

        with torch.no_grad():
            # Softmax once; its argmax equals the argmax of the raw logits.
            probs = torch.softmax(model(tensor), dim=1)[0]
        top_prob, top_index = torch.max(probs, dim=0)

        label = class_names[top_index.item()]
        return (
            label,
            "กินได้" if label == "Edible" else "พิษ",
            f"{top_prob.item() * 100:.2f}%",
            mushroom_species[label],
            bg_removed,
        )
    except Exception as e:
        # UI boundary: keep the app responsive, but log the full traceback so
        # failures are diagnosable instead of just a one-line message.
        print(f"❌ Error: {e}")
        traceback.print_exc()
        return "Error", "ผิดพลาด", "N/A", "N/A", image
# 📆 Version info shown in the UI footer.
# "Last updated" comes from this file's modification time, i.e. when the code
# was last written/deployed. The previous datetime.now() only reported the day
# the process happened to start.
try:
    _last_updated = datetime.datetime.fromtimestamp(os.path.getmtime(__file__))
except (NameError, OSError):
    # No __file__ (e.g. interactive session) or unreadable file: use today.
    _last_updated = datetime.datetime.now()
# Two trailing spaces before "\n" form a Markdown hard line break, so the two
# fields render on separate lines instead of being joined by a space.
version_info = (
    "**Version:** 1.0.1  \n**Last updated:** " + _last_updated.strftime("%Y-%m-%d")
)
# 🎛️ Gradio UI — only built and launched when this file is run as a script.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("## 🍄 Mushroom Safety Classifier")
        # Two trailing spaces before "\n" force a Markdown hard line break; a
        # bare "\n" is a soft break and renders both sentences on one line.
        gr.Markdown(
            "Upload a mushroom photo to check if it’s edible or poisonous.  \n"
            "อัปโหลดรูปเห็ดเพื่อทำนายว่าเห็ดกินได้หรือมีพิษ"
        )
        # Safety notice: a wrong "Edible" prediction could lead someone to eat
        # a poisonous mushroom, so state plainly that the output is not advice.
        gr.Markdown(
            "⚠️ **For educational use only.** Predictions can be wrong — never eat "
            "a wild mushroom based on this app; consult a qualified expert.  \n"
            "⚠️ **ใช้เพื่อการศึกษาเท่านั้น** ผลการทำนายอาจผิดพลาด "
            "ห้ามใช้ตัดสินใจรับประทานเห็ดป่า โปรดปรึกษาผู้เชี่ยวชาญ"
        )
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="📷 Upload Mushroom Image")
                # 🆕 Shows the background-removed image returned by the classifier.
                cleaned_image = gr.Image(label="🧼 Background Removed Image")
            with gr.Column():
                label_en = gr.Textbox(label="🧠 Prediction (English)")
                label_th = gr.Textbox(label="🗣️ คำทำนาย (ภาษาไทย)")
                confidence = gr.Textbox(label="📶 Confidence Score")
                label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")
                classify_btn = gr.Button("🔍 Classify")
        # The outputs list must follow the order of the 5-tuple that
        # classify_mushroom returns.
        classify_btn.click(
            fn=classify_mushroom,
            inputs=image_input,
            outputs=[label_en, label_th, confidence, label_hint, cleaned_image],
        )
        gr.Markdown("---")
        gr.Markdown(version_info)  # 📅 Version and update info
    demo.launch()