trapezius60 commited on
Commit
5291024
·
verified ·
1 Parent(s): 31d02c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -11,7 +11,10 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # 📦 Load your fine-tuned model
13
  model = models.resnet50(weights=None) # No pretrained weights
14
- model.fc = torch.nn.Linear(model.fc.in_features, 2)
 
 
 
15
 
16
  model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
17
  model = model.to(device)
 
11
 
12
  # 📦 Load your fine-tuned model
13
  model = models.resnet50(weights=None) # No pretrained weights
14
+ model.fc = torch.nn.Sequential(
15
+ torch.nn.Dropout(0.5),
16
+ torch.nn.Linear(model.fc.in_features, 2)
17
+ )
18
 
19
  model.load_state_dict(torch.load("resnet_mushroom_classifier.pth", map_location=device))
20
  model = model.to(device)