Update vit_model_test.py
Browse files- vit_model_test.py +2 -1
vit_model_test.py
CHANGED
@@ -29,13 +29,14 @@ if __name__ == "__main__":
|
|
29 |
# Check for GPU availability
|
30 |
device = torch.device('cuda')
|
31 |
|
|
|
32 |
# Load the pre-trained ViT model and move it to GPU
|
33 |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
|
34 |
|
35 |
|
36 |
|
37 |
model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
|
38 |
-
#
|
39 |
preprocess = transforms.Compose([
|
40 |
transforms.Resize((224, 224)),
|
41 |
transforms.ToTensor()
|
|
|
29 |
# Check for GPU availability
|
30 |
device = torch.device('cuda')
|
31 |
|
32 |
+
#this code runs only with nvidia gpu
|
33 |
# Load the pre-trained ViT model and move it to GPU
|
34 |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
|
35 |
|
36 |
|
37 |
|
38 |
model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
|
39 |
+
# resize image and make it a tensor (add dimension)
|
40 |
preprocess = transforms.Compose([
|
41 |
transforms.Resize((224, 224)),
|
42 |
transforms.ToTensor()
|