Spaces:
Build error
Build error
File size: 3,448 Bytes
21135f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision import models
import gradio as gr
from rembg import remove # 🆕 Background removal
import io
import datetime
# 🔧 Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 📦 Load your fine-tuned model
# Build the ResNet-50 architecture with no downloaded weights (`weights=None`
# replaces the `pretrained=False` argument deprecated since torchvision 0.13),
# then replace the classifier head with a 2-way linear layer. Every parameter
# is subsequently overwritten by the fine-tuned checkpoint, so the checkpoint
# must be a state_dict for exactly this architecture (ResNet-50 + 2-class fc).
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)  # 2 classes: Edible, Poisonous
# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved on a GPU still loads on a CPU-only host.
model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
model = model.to(device)
# Evaluation mode: BatchNorm layers use their running statistics at inference.
model.eval()
# 🏷️ Class names — indexed by the model's predicted class index, so the order
# must match the order of the output units used during training.
class_names = ['Edible', 'Poisonous']


def _species_hint(*species):
    """Return the bulleted "Possible species" text shown in the UI."""
    return "\n• ".join(("Possible species:",) + species)


# 🍄 Mapping for more detailed species, keyed by entries of class_names.
# NOTE(review): Amanita citrina is widely described as inedible / not
# recommended and is easily confused with the deadly A. phalloides — confirm
# this grouping against the training labels before presenting it as "Edible".
mushroom_species = {
    "Edible": _species_hint(
        "Amanita citrina",
        "Russula delica",
        "Phaeogyroporus portentosus",
    ),
    "Poisonous": _species_hint(
        "Amanita phalloides",
        "Inocybe rimosa",
    ),
}
# 🎨 Image preprocessing (must match training)
# Per-channel (R, G, B) normalization constants — the widely used ImageNet
# mean / standard deviation.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        # Fixed 224x224 output; the aspect ratio is not preserved.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor laid out as (C, H, W).
        transforms.ToTensor(),
        # (x - mean) / std, channel by channel.
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# 🧠 Prediction function
def classify_mushroom(image: Image.Image):
    """Remove the background from a mushroom photo and classify it.

    Args:
        image: Photo from the Gradio ``Image(type="pil")`` input. May be
            ``None`` when Classify is clicked before anything is uploaded.

    Returns:
        A 5-tuple ``(label_en, label_th, confidence, species_hint, preview)``
        in the order of the Gradio output components: the English label
        (an entry of ``class_names``), the Thai label, the softmax probability
        of the predicted class formatted as ``"NN.NN%"``, the species hint
        text from ``mushroom_species``, and the background-removed image.
        On any failure the text fields are ``"Error"``, ``"ผิดพลาด"``,
        ``"N/A"``, ``"N/A"`` and the preview is the input image (``None``
        when nothing was uploaded).
    """
    if image is None:
        # Nothing uploaded: return the error tuple directly instead of
        # relying on an AttributeError from None.convert().
        return "Error", "ผิดพลาด", "N/A", "N/A", None
    try:
        image = image.convert("RGB")
        # 🆕 Remove background
        bg_removed = remove(image)
        # 🆗 Prepare cleaned image for model.
        # rembg returns an RGBA cutout for PIL input. ToTensor would turn that
        # into 4 channels, which the 3-channel Normalize (and the ResNet stem)
        # cannot accept, so every request previously ended in the error branch.
        # Drop alpha for the model input only; transparent areas keep whatever
        # RGB values rembg stored there (typically black).
        # NOTE(review): confirm this background matches training preprocessing.
        model_input = bg_removed.convert("RGB")
        tensor = transform(model_input).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(tensor)
        _, predicted = torch.max(outputs, 1)
        class_idx = predicted.item()
        label = class_names[class_idx]
        score = torch.softmax(outputs, dim=1)[0][class_idx].item() * 100
        suggestion = mushroom_species[label]
        label_th = "กินได้" if label == "Edible" else "พิษ"
        # Show the RGBA cutout (not model_input) so the preview keeps transparency.
        return label, label_th, f"{score:.2f}%", suggestion, bg_removed
    except Exception as e:
        # UI boundary: log and return an error tuple rather than crash the app.
        print(f"❌ Error: {e}")
        return "Error", "ผิดพลาด", "N/A", "N/A", image
# 📆 Version Info — Markdown snippet rendered at the bottom of the UI.
# NOTE(review): the date is computed when this module is executed, so
# "Last updated" actually shows the day the app was started, not the day the
# code last changed. Consider hard-coding it alongside the version number.
version_info = f"**Version:** 1.0.1 \n**Last updated:** {datetime.date.today().isoformat()}"
# 🎛️ Gradio UI — built and launched only when this file is run as a script.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        # Title plus bilingual (English / Thai) instructions.
        gr.Markdown("## 🍄 Mushroom Safety Classifier")
        gr.Markdown("Upload a mushroom photo to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดเพื่อทำนายว่าเห็ดกินได้หรือมีพิษ")
        with gr.Row():
            # Left column: the uploaded photo and its background-removed version.
            with gr.Column():
                photo = gr.Image(type="pil", label="📷 Upload Mushroom Image")
                cutout_preview = gr.Image(label="🧼 Background Removed Image")
            # Right column: text results, then the button that triggers them.
            with gr.Column():
                text_results = [
                    gr.Textbox(label="🧠 Prediction (English)"),
                    gr.Textbox(label="🗣️ คำทำนาย (ภาษาไทย)"),
                    gr.Textbox(label="📶 Confidence Score"),
                    gr.Textbox(label="🏷️ Likely Species (Based on Training Data)"),
                ]
                run_button = gr.Button("🔍 Classify")
        # Output order must line up with the 5-tuple classify_mushroom returns.
        run_button.click(
            fn=classify_mushroom,
            inputs=photo,
            outputs=[*text_results, cutout_preview],
        )
        # Footer: separator and version / date info.
        gr.Markdown("---")
        gr.Markdown(version_info)
    demo.launch()
|