# Source: VIT_Demo / vit_model_test.py
# Uploaded by benjaminStreltzin ("Upload vit_model_test.py"), commit 7fc845d (verified), 2.73 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image
import os
import pandas as pd
class CustomDataset(Dataset):
    """Dataset that yields images whose file paths are listed in a DataFrame.

    Only the first column of ``dataframe`` is read, and it must hold image
    file paths. No label is returned, so each item is just the (optionally
    transformed) image.

    Args:
        dataframe: pandas DataFrame whose first column holds image paths.
        transform: Optional callable applied to each loaded RGB PIL image
            (in this script, a torchvision ``Compose`` pipeline).
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (image paths) in the DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the image at row ``idx`` as RGB and apply ``transform`` if set."""
        # Image path is in the first column.
        image_path = self.dataframe.iloc[idx, 0]
        # Use a context manager so the file handle is always released. Pillow
        # keeps the file open after load() for multi-frame formats.
        # convert() returns a fully loaded in-memory image, so it remains
        # valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
if __name__ == "__main__":
    # Use the GPU when one is available; otherwise fall back to the CPU so the
    # script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pre-trained ViT model and replace its classification head with
    # a 2-class linear layer. All weights, including this new head, are then
    # replaced by the fine-tuned checkpoint loaded below.
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
    model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

    # Define the image preprocessing pipeline.
    # NOTE(review): there is no Normalize step; presumably this matches the
    # preprocessing used at training time. Confirm before changing.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # Load the test dataset.
    # TODO: receive the image from Gradio/Streamlit instead of a local folder.
    test_set = 'datasets/'
    # Keep regular files only (a subdirectory would make Image.open fail) and
    # sort the names so predictions are reported in a deterministic order.
    image_paths = [
        os.path.join(test_set, filename)
        for filename in sorted(os.listdir(test_set))
        if os.path.isfile(os.path.join(test_set, filename))
    ]
    dataset = pd.DataFrame({'image_path': image_paths})
    test_dataset = CustomDataset(dataset, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Load the trained weights. map_location lets a checkpoint saved on a GPU
    # be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True on recent PyTorch versions).
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))

    # Evaluate the model.
    model.eval()
    confidences = []
    predicted_labels = []
    with torch.no_grad():
        for images in test_loader:
            images = images.to(device)
            outputs = model(images)
            logits = outputs.logits  # Extract logits from the output
            probabilities = F.softmax(logits, dim=1)
            # Highest class probability and its class index, per image.
            confidences_per_image, predicted = torch.max(probabilities, 1)
            # .tolist() gives plain Python ints/floats for readable printing.
            predicted_labels.extend(predicted.cpu().tolist())
            confidences.extend(confidences_per_image.cpu().tolist())

    print(predicted_labels)
    print(confidences)
    confidence_percentages = [confidence * 100 for confidence in confidences]
    for label, confidence in zip(predicted_labels, confidence_percentages):
        print(f"Predicted label: {label}, Confidence: {confidence:.2f}%")