trapezius60 commited on
Commit
21135f4
·
verified ·
1 Parent(s): 0a7d2fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -82
app.py CHANGED
@@ -1,82 +1,92 @@
1
- from PIL import Image
2
- import torch
3
- import torchvision.transforms as transforms
4
- from torchvision import models
5
- import gradio as gr
6
- from rembg import remove # 🆕 Background removal
7
-
8
- # 🔧 Set device
9
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
-
11
- # 📦 Load your fine-tuned model
12
- model = models.resnet50(pretrained=False)
13
- model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
14
- model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
15
- model = model.to(device)
16
- model.eval()
17
-
18
- # 🏷️ Class names
19
- class_names = ['Edible', 'Poisonous']
20
-
21
- # 🍄 Mapping for more detailed species
22
- mushroom_species = {
23
- "Edible": "Possible species:\n• Amanita citrina\n• Russula delica\n• Phaeogyroporus portentosus",
24
- "Poisonous": "Possible species:\n• Amanita phalloides\n• Inocybe rimosa"
25
- }
26
-
27
- # 🎨 Image preprocessing (must match training)
28
- transform = transforms.Compose([
29
- transforms.Resize((224, 224)),
30
- transforms.ToTensor(),
31
- transforms.Normalize([0.485, 0.456, 0.406],
32
- [0.229, 0.224, 0.225])
33
- ])
34
-
35
- # 🧠 Prediction function with background removal
36
- def classify_mushroom(image: Image.Image):
37
- try:
38
- image = image.convert("RGB")
39
-
40
- # 🆕 Remove background
41
- image = remove(image) # returns RGBA with transparency
42
- image = image.convert("RGB") # back to RGB
43
-
44
- tensor = transform(image).unsqueeze(0).to(device)
45
-
46
- with torch.no_grad():
47
- outputs = model(tensor)
48
- _, predicted = torch.max(outputs, 1)
49
- label = class_names[predicted.item()]
50
- score = torch.softmax(outputs, dim=1)[0][predicted.item()].item() * 100
51
-
52
- suggestion = mushroom_species[label]
53
-
54
- return label, "กินได้" if label == "Edible" else "พิษ", f"{score:.2f}%", suggestion
55
-
56
- except Exception as e:
57
- print(f"❌ Error: {e}")
58
- return "Error", "ผิดพลาด", "N/A", "N/A"
59
-
60
- # 🎛️ Gradio UI
61
- if __name__ == "__main__":
62
- with gr.Blocks() as demo:
63
- gr.Markdown("## 🍄 Mushroom Safety Classifier")
64
- gr.Markdown("Upload a mushroom photo to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดเพื่อทำนายว่าเห็ดกินได้หรือมีพิษ")
65
-
66
- with gr.Row():
67
- image_input = gr.Image(type="pil", label="📷 Upload Mushroom Image")
68
- with gr.Column():
69
- label_en = gr.Textbox(label="🧠 Prediction (English)")
70
- label_th = gr.Textbox(label="🗣️ คำทำนาย (ภาษาไทย)")
71
- confidence = gr.Textbox(label="📶 Confidence Score")
72
- label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")
73
-
74
- classify_btn = gr.Button("🔍 Classify")
75
-
76
- classify_btn.click(
77
- fn=classify_mushroom,
78
- inputs=image_input,
79
- outputs=[label_en, label_th, confidence, label_hint]
80
- )
81
-
82
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from torchvision import models
5
+ import gradio as gr
6
+ from rembg import remove # 🆕 Background removal
7
+ import io
8
+ import datetime
9
+
10
+ # 🔧 Set device
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
+ # 📦 Load your fine-tuned model
14
+ model = models.resnet50(pretrained=False)
15
+ model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
16
+ model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
17
+ model = model.to(device)
18
+ model.eval()
19
+
20
+ # 🏷️ Class names
21
+ class_names = ['Edible', 'Poisonous']
22
+
23
+ # 🍄 Mapping for more detailed species
24
+ mushroom_species = {
25
+ "Edible": "Possible species:\n• Amanita citrina\n• Russula delica\n• Phaeogyroporus portentosus",
26
+ "Poisonous": "Possible species:\n• Amanita phalloides\n• Inocybe rimosa"
27
+ }
28
+
29
+ # 🎨 Image preprocessing (must match training)
30
+ transform = transforms.Compose([
31
+ transforms.Resize((224, 224)),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize([0.485, 0.456, 0.406],
34
+ [0.229, 0.224, 0.225])
35
+ ])
36
+
37
+ # 🧠 Prediction function
38
+ def classify_mushroom(image: Image.Image):
39
+ try:
40
+ image = image.convert("RGB")
41
+
42
+ # 🆕 Remove background
43
+ bg_removed = remove(image)
44
+
45
+ # 🆗 Prepare cleaned image for model
46
+ tensor = transform(bg_removed).unsqueeze(0).to(device)
47
+
48
+ with torch.no_grad():
49
+ outputs = model(tensor)
50
+ _, predicted = torch.max(outputs, 1)
51
+ label = class_names[predicted.item()]
52
+ score = torch.softmax(outputs, dim=1)[0][predicted.item()].item() * 100
53
+
54
+ suggestion = mushroom_species[label]
55
+
56
+ return label, "กินได้" if label == "Edible" else "พิษ", f"{score:.2f}%", suggestion, bg_removed
57
+
58
+ except Exception as e:
59
+ print(f"❌ Error: {e}")
60
+ return "Error", "ผิดพลาด", "N/A", "N/A", image
61
+
62
+ # 📆 Version Info
63
+ version_info = "**Version:** 1.0.1 \n**Last updated:** " + datetime.datetime.now().strftime("%Y-%m-%d")
64
+
65
+ # 🎛️ Gradio UI
66
+ if __name__ == "__main__":
67
+ with gr.Blocks() as demo:
68
+ gr.Markdown("## 🍄 Mushroom Safety Classifier")
69
+ gr.Markdown("Upload a mushroom photo to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดเพื่อทำนายว่าเห็ดกินได้หรือมีพิษ")
70
+
71
+ with gr.Row():
72
+ with gr.Column():
73
+ image_input = gr.Image(type="pil", label="📷 Upload Mushroom Image")
74
+ cleaned_image = gr.Image(label="🧼 Background Removed Image") # 🆕 Show cleaned image
75
+ with gr.Column():
76
+ label_en = gr.Textbox(label="🧠 Prediction (English)")
77
+ label_th = gr.Textbox(label="🗣️ คำทำนาย (ภาษาไทย)")
78
+ confidence = gr.Textbox(label="📶 Confidence Score")
79
+ label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")
80
+
81
+ classify_btn = gr.Button("🔍 Classify")
82
+
83
+ classify_btn.click(
84
+ fn=classify_mushroom,
85
+ inputs=image_input,
86
+ outputs=[label_en, label_th, confidence, label_hint, cleaned_image]
87
+ )
88
+
89
+ gr.Markdown("---")
90
+ gr.Markdown(version_info) # 📅 Version and update info
91
+
92
+ demo.launch()