trapezius60 commited on
Commit
ae7b3ca
Β·
verified Β·
1 Parent(s): 4e16063

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -6,13 +6,17 @@ from torchvision import models
6
  import gradio as gr
7
  from rembg import remove # Background removal
8
  from transformers import pipeline # For non-mushroom detection
 
9
 
10
  # πŸ”§ Set device
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  # πŸ“¦ Load your fine-tuned model
14
- model = models.resnet50(pretrained=False)
15
- model.fc = torch.nn.Linear(model.fc.in_features, 2) # 2 classes: Edible, Poisonous
 
 
 
16
  model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
17
  model = model.to(device)
18
  model.eval()
 
6
  import gradio as gr
7
  from rembg import remove # Background removal
8
  from transformers import pipeline # For non-mushroom detection
9
+ from torchvision.models import ResNet50_Weights
10
 
11
  # πŸ”§ Set device
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
  # πŸ“¦ Load your fine-tuned model
15
+ model = models.resnet50(weights=None) # No pretrained weights
16
+ model.fc = torch.nn.Sequential(
17
+ torch.nn.Dropout(0.5),
18
+ torch.nn.Linear(model.fc.in_features, 2)
19
+ )
20
  model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
21
  model = model.to(device)
22
  model.eval()