benjaminStreltzin commited on
Commit
dbc65b9
·
verified ·
1 Parent(s): 333fd10

Update vit_model_test.py

Browse files
Files changed (1) hide show
  1. vit_model_test.py +1 -1
vit_model_test.py CHANGED
@@ -15,7 +15,7 @@ class CustomModel:
15
  self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
16
 
17
  # Load model weights
18
- self.model.load_state_dict(torch.load('trained_model.pth', map_location=self.device))
19
  self.model.eval()
20
 
21
  # Define the image preprocessing pipeline
 
15
  self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
16
 
17
  # Load model weights
18
+ self.model.load_state_dict(torch.load('trained_model.pth', map_location=self.device, weights_only=True))
19
  self.model.eval()
20
 
21
  # Define the image preprocessing pipeline