Image Classification
Transformers
English
art
benjaminStreltzin commited on
Commit
134bd4d
·
verified ·
1 Parent(s): cc4c541

Update vit_model_test.py

Browse files
Files changed (1) hide show
  1. vit_model_test.py +2 -1
vit_model_test.py CHANGED
@@ -29,13 +29,14 @@ if __name__ == "__main__":
29
  # Check for GPU availability
30
  device = torch.device('cuda')
31
 
 
32
  # Load the pre-trained ViT model and move it to GPU
33
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
34
 
35
 
36
 
37
  model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
38
- # Define the image preprocessing pipeline
39
  preprocess = transforms.Compose([
40
  transforms.Resize((224, 224)),
41
  transforms.ToTensor()
 
29
  # Check for GPU availability
30
  device = torch.device('cuda')
31
 
32
+ #this code runs only with nvidia gpu
33
  # Load the pre-trained ViT model and move it to GPU
34
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
35
 
36
 
37
 
38
  model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
39
+ # resize image and make it a tensor (add dimension)
40
  preprocess = transforms.Compose([
41
  transforms.Resize((224, 224)),
42
  transforms.ToTensor()