trapezius60 commited on
Commit
9219ef7
·
verified ·
1 Parent(s): f0d825c

Create app.py

Browse files

this version if apply fine-tune mushroom classification
Loads your fine-tuned PyTorch ResNet model.

Uses only two main classes: "Edible" and "Poisonous".

When the model predicts a class, it shows extra information suggesting possible mushroom species from your dataset:

Edible → "Amanita citrina", "Russula delica", "Phaeogyroporus portentosus"

Poisonous → "Amanita phalloides", "Inocybe rimosa"

Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from torchvision import models
5
+ import gradio as gr
6
+
7
+ # 🔧 Set device
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ # 📦 Load your fine-tuned model
11
+ model = models.resnet50(pretrained=False)
12
+ model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
13
+ model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
14
+ model = model.to(device)
15
+ model.eval()
16
+
17
+ # 🏷️ Class names
18
+ class_names = ['Edible', 'Poisonous']
19
+
20
+ # 🍄 Mapping for more detailed species
21
+ mushroom_species = {
22
+ "Edible": "Possible species:\n• Amanita citrina\n• Russula delica\n• Phaeogyroporus portentosus",
23
+ "Poisonous": "Possible species:\n• Amanita phalloides\n• Inocybe rimosa"
24
+ }
25
+
26
+ # 🎨 Image preprocessing (must match training)
27
+ transform = transforms.Compose([
28
+ transforms.Resize((224, 224)),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize([0.485, 0.456, 0.406],
31
+ [0.229, 0.224, 0.225])
32
+ ])
33
+
34
+ # 🧠 Prediction function
35
+ def classify_mushroom(image: Image.Image):
36
+ try:
37
+ image = image.convert("RGB")
38
+ tensor = transform(image).unsqueeze(0).to(device)
39
+
40
+ with torch.no_grad():
41
+ outputs = model(tensor)
42
+ _, predicted = torch.max(outputs, 1)
43
+ label = class_names[predicted.item()]
44
+ score = torch.softmax(outputs, dim=1)[0][predicted.item()].item() * 100
45
+
46
+ suggestion = mushroom_species[label]
47
+
48
+ return label, "กินได้" if label == "Edible" else "พิษ", f"{score:.2f}%", suggestion
49
+
50
+ except Exception as e:
51
+ print(f"❌ Error: {e}")
52
+ return "Error", "ผิดพลาด", "N/A", "N/A"
53
+
54
+ # 🎛️ Gradio UI
55
+ if __name__ == "__main__":
56
+ with gr.Blocks() as demo:
57
+ gr.Markdown("## 🍄 Mushroom Safety Classifier")
58
+ gr.Markdown("Upload a mushroom photo to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดเพื่อทำนายว่าเห็ดกินได้หรือมีพิษ")
59
+
60
+ with gr.Row():
61
+ image_input = gr.Image(type="pil", label="📷 Upload Mushroom Image")
62
+ with gr.Column():
63
+ label_en = gr.Textbox(label="🧠 Prediction (English)")
64
+ label_th = gr.Textbox(label="🗣️ คำทำนาย (ภาษาไทย)")
65
+ confidence = gr.Textbox(label="📶 Confidence Score")
66
+ label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")
67
+
68
+ classify_btn = gr.Button("🔍 Classify")
69
+
70
+ classify_btn.click(
71
+ fn=classify_mushroom,
72
+ inputs=image_input,
73
+ outputs=[label_en, label_th, confidence, label_hint]
74
+ )
75
+
76
+ demo.launch()