benjaminStreltzin commited on
Commit
6976bb1
·
verified ·
1 Parent(s): df876db

Update vit_model_test.py

Browse files
Files changed (1) hide show
  1. vit_model_test.py +6 -4
vit_model_test.py CHANGED
@@ -7,13 +7,15 @@ from PIL import Image
7
 
8
  class CustomModel:
9
  def __init__(self):
10
- # Check for GPU availability
11
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
- # Load the pre-trained ViT model and move it to GPU
14
  self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(self.device)
15
  self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
16
- self.model.load_state_dict(torch.load('trained_model.pth'))
 
 
17
  self.model.eval()
18
 
19
  # Define the image preprocessing pipeline
 
7
 
8
  class CustomModel:
9
  def __init__(self):
10
+ # Explicitly set the device to CPU
11
+ self.device = torch.device('cpu')
12
 
13
+ # Load the pre-trained ViT model and move it to CPU
14
  self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(self.device)
15
  self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
16
+
17
+ # Load model weights
18
+ self.model.load_state_dict(torch.load('trained_model.pth', map_location=self.device))
19
  self.model.eval()
20
 
21
  # Define the image preprocessing pipeline