Spaces:
Sleeping
Sleeping
Update vit_model_test.py
Browse files- vit_model_test.py +6 -4
vit_model_test.py
CHANGED
@@ -7,13 +7,15 @@ from PIL import Image
|
|
7 |
|
8 |
class CustomModel:
|
9 |
def __init__(self):
|
10 |
-
#
|
11 |
-
self.device = torch.device('
|
12 |
|
13 |
-
# Load the pre-trained ViT model and move it to
|
14 |
self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(self.device)
|
15 |
self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
|
16 |
-
|
|
|
|
|
17 |
self.model.eval()
|
18 |
|
19 |
# Define the image preprocessing pipeline
|
|
|
7 |
|
8 |
class CustomModel:
|
9 |
def __init__(self):
|
10 |
+
# Explicitly set the device to CPU
|
11 |
+
self.device = torch.device('cpu')
|
12 |
|
13 |
+
# Load the pre-trained ViT model and move it to CPU
|
14 |
self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(self.device)
|
15 |
self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
|
16 |
+
|
17 |
+
# Load model weights
|
18 |
+
self.model.load_state_dict(torch.load('trained_model.pth', map_location=self.device))
|
19 |
self.model.eval()
|
20 |
|
21 |
# Define the image preprocessing pipeline
|