benjaminStreltzin commited on
Commit
7fc845d
·
verified ·
1 Parent(s): 71fe305

Upload vit_model_test.py

Browse files
Files changed (1) hide show
  1. vit_model_test.py +95 -0
vit_model_test.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torchvision import transforms
6
+ from transformers import ViTForImageClassification
7
+ from PIL import Image
8
+ import os
9
+ import pandas as pd
10
+
11
+
12
+
13
+
14
+ class CustomDataset(Dataset):
15
+ def __init__(self, dataframe, transform=None):
16
+ self.dataframe = dataframe
17
+ self.transform = transform
18
+
19
+ def __len__(self):
20
+ return len(self.dataframe)
21
+
22
+ def __getitem__(self, idx):
23
+ image_path = self.dataframe.iloc[idx, 0] # Image path is in the first column
24
+ image = Image.open(image_path).convert('RGB') # Convert to RGB format
25
+
26
+ if self.transform:
27
+ image = self.transform(image)
28
+
29
+ return image
30
+
31
+
32
+ if __name__ == "__main__":
33
+ # Check for GPU availability
34
+ device = torch.device('cuda')
35
+
36
+ # Load the pre-trained ViT model and move it to GPU
37
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
38
+
39
+
40
+
41
+ model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
42
+ # Define the image preprocessing pipeline
43
+ preprocess = transforms.Compose([
44
+ transforms.Resize((224, 224)),
45
+ transforms.ToTensor()
46
+ ])
47
+
48
+
49
+
50
+
51
+
52
+ # Load the test dataset
53
+
54
+
55
+ ### need to recive image from gratio/streamlit
56
+
57
+ test_set = 'datasets/'
58
+
59
+ image_paths = []
60
+ for filename in os.listdir(test_set):
61
+ image_paths.append(os.path.join(test_set, filename))
62
+ dataset = pd.DataFrame({'image_path': image_paths})
63
+
64
+
65
+
66
+ test_dataset = CustomDataset(dataset, transform=preprocess)
67
+ test_loader = DataLoader(test_dataset, batch_size=32)
68
+
69
+ # Load the trained model
70
+ model.load_state_dict(torch.load('trained_model.pth'))
71
+
72
+ # Evaluate the model
73
+ model.eval()
74
+ confidences = []
75
+ predicted_labels = []
76
+
77
+
78
+ with torch.no_grad():
79
+ for images in test_loader:
80
+ images = images.to(device)
81
+ outputs = model(images)
82
+ logits = outputs.logits # Extract logits from the output
83
+ probabilities = F.softmax(logits, dim=1)
84
+ confidences_per_image, predicted = torch.max(probabilities, 1)
85
+ predicted_labels.extend(predicted.cpu().numpy())
86
+ confidences.extend(confidences_per_image.cpu().numpy())
87
+
88
+
89
+ print(predicted_labels)
90
+ print(confidences)
91
+
92
+ confidence_percentages = [confidence * 100 for confidence in confidences]
93
+ for label, confidence in zip(predicted_labels, confidence_percentages):
94
+ print(f"Predicted label: {label}, Confidence: {confidence:.2f}%")
95
+