trapezius60 commited on
Commit
6b484bd
·
verified ·
1 Parent(s): 3f876b8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -76
app.py CHANGED
@@ -1,76 +1,82 @@
1
- from PIL import Image
2
- import torch
3
- import torchvision.transforms as transforms
4
- from torchvision import models
5
- import gradio as gr
6
-
7
- # 🔧 Set device
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
-
10
- # 📦 Load your fine-tuned model
11
- model = models.resnet50(pretrained=False)
12
- model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
13
- model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
14
- model = model.to(device)
15
- model.eval()
16
-
17
- # 🏷️ Class names
18
- class_names = ['Edible', 'Poisonous']
19
-
20
- # 🍄 Mapping for more detailed species
21
- mushroom_species = {
22
- "Edible": "Possible species:\n• Amanita citrina\n• Russula delica\n• Phaeogyroporus portentosus",
23
- "Poisonous": "Possible species:\n• Amanita phalloides\n• Inocybe rimosa"
24
- }
25
-
26
- # 🎨 Image preprocessing (must match training)
27
- transform = transforms.Compose([
28
- transforms.Resize((224, 224)),
29
- transforms.ToTensor(),
30
- transforms.Normalize([0.485, 0.456, 0.406],
31
- [0.229, 0.224, 0.225])
32
- ])
33
-
34
- # 🧠 Prediction function
35
- def classify_mushroom(image: Image.Image):
36
- try:
37
- image = image.convert("RGB")
38
- tensor = transform(image).unsqueeze(0).to(device)
39
-
40
- with torch.no_grad():
41
- outputs = model(tensor)
42
- _, predicted = torch.max(outputs, 1)
43
- label = class_names[predicted.item()]
44
- score = torch.softmax(outputs, dim=1)[0][predicted.item()].item() * 100
45
-
46
- suggestion = mushroom_species[label]
47
-
48
- return label, "กินได้" if label == "Edible" else "พิษ", f"{score:.2f}%", suggestion
49
-
50
- except Exception as e:
51
- print(f"❌ Error: {e}")
52
- return "Error", "ผิดพลาด", "N/A", "N/A"
53
-
54
- # 🎛️ Gradio UI
55
- if __name__ == "__main__":
56
- with gr.Blocks() as demo:
57
- gr.Markdown("## 🍄 Mushroom Safety Classifier")
58
- gr.Markdown("Upload a mushroom photo to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดเพื่อทำนายว่าเห็ดกินได้หรือมีพิษ")
59
-
60
- with gr.Row():
61
- image_input = gr.Image(type="pil", label="📷 Upload Mushroom Image")
62
- with gr.Column():
63
- label_en = gr.Textbox(label="🧠 Prediction (English)")
64
- label_th = gr.Textbox(label="🗣️ คำทำนาย (ภาษาไทย)")
65
- confidence = gr.Textbox(label="📶 Confidence Score")
66
- label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")
67
-
68
- classify_btn = gr.Button("🔍 Classify")
69
-
70
- classify_btn.click(
71
- fn=classify_mushroom,
72
- inputs=image_input,
73
- outputs=[label_en, label_th, confidence, label_hint]
74
- )
75
-
76
- demo.launch()
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from torchvision import models
5
+ import gradio as gr
6
+ from rembg import remove # 🆕 Background removal
7
+
8
+ # 🔧 Set device
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ # 📦 Load your fine-tuned model
12
+ model = models.resnet50(pretrained=False)
13
+ model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
14
+ model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
15
+ model = model.to(device)
16
+ model.eval()
17
+
18
+ # 🏷️ Class names
19
+ class_names = ['Edible', 'Poisonous']
20
+
21
+ # 🍄 Mapping for more detailed species
22
+ mushroom_species = {
23
+ "Edible": "Possible species:\n• Amanita citrina\n• Russula delica\n• Phaeogyroporus portentosus",
24
+ "Poisonous": "Possible species:\n• Amanita phalloides\n• Inocybe rimosa"
25
+ }
26
+
27
+ # 🎨 Image preprocessing (must match training)
28
+ transform = transforms.Compose([
29
+ transforms.Resize((224, 224)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize([0.485, 0.456, 0.406],
32
+ [0.229, 0.224, 0.225])
33
+ ])
34
+
35
+ # 🧠 Prediction function with background removal
36
+ def classify_mushroom(image: Image.Image):
37
+ try:
38
+ image = image.convert("RGB")
39
+
40
+ # 🆕 Remove background
41
+ image = remove(image) # returns RGBA with transparency
42
+ image = image.convert("RGB") # back to RGB
43
+
44
+ tensor = transform(image).unsqueeze(0).to(device)
45
+
46
+ with torch.no_grad():
47
+ outputs = model(tensor)
48
+ _, predicted = torch.max(outputs, 1)
49
+ label = class_names[predicted.item()]
50
+ score = torch.softmax(outputs, dim=1)[0][predicted.item()].item() * 100
51
+
52
+ suggestion = mushroom_species[label]
53
+
54
+ return label, "กินได้" if label == "Edible" else "พิษ", f"{score:.2f}%", suggestion
55
+
56
+ except Exception as e:
57
+ print(f" Error: {e}")
58
+ return "Error", "ผิดพลาด", "N/A", "N/A"
59
+
60
+ # 🎛️ Gradio UI
61
+ if __name__ == "__main__":
62
+ with gr.Blocks() as demo:
63
+ gr.Markdown("## 🍄 Mushroom Safety Classifier")
64
+ gr.Markdown("Upload a mushroom photo to check if it’s edible or poisonous.\nอัปโหลดรูปเห็ดเพื่อทำนายว่าเห็ดกินได้หรือมีพิษ")
65
+
66
+ with gr.Row():
67
+ image_input = gr.Image(type="pil", label="📷 Upload Mushroom Image")
68
+ with gr.Column():
69
+ label_en = gr.Textbox(label="🧠 Prediction (English)")
70
+ label_th = gr.Textbox(label="🗣️ คำทำนาย (ภาษาไทย)")
71
+ confidence = gr.Textbox(label="📶 Confidence Score")
72
+ label_hint = gr.Textbox(label="🏷️ Likely Species (Based on Training Data)")
73
+
74
+ classify_btn = gr.Button("🔍 Classify")
75
+
76
+ classify_btn.click(
77
+ fn=classify_mushroom,
78
+ inputs=image_input,
79
+ outputs=[label_en, label_th, confidence, label_hint]
80
+ )
81
+
82
+ demo.launch()