# Repository page header captured together with the file (not code):
# VIT_Demo / vit_model_test.py
# benjaminStreltzin's picture
# Update vit_model_test.py
# 22aae95 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image
class CustomModel:
    """CPU-only wrapper around a ViT image classifier with a 2-class head.

    The constructor loads the pre-trained ``google/vit-base-patch16-224``
    checkpoint, swaps its classifier for a 2-output linear layer, and then
    loads fine-tuned weights from a local state-dict file. ``predict`` runs a
    single PIL image through the model and reports the winning class index
    together with its softmax confidence as a percentage.
    """

    def __init__(self, weights_path='trained_model.pth'):
        """Build the model, load the fine-tuned weights and set eval mode.

        Args:
            weights_path: Path to the state-dict file with the fine-tuned
                weights. Defaults to ``'trained_model.pth'``, the path that
                was previously hard-coded.
        """
        # Explicitly pin everything to the CPU.
        self.device = torch.device('cpu')

        # Load the pre-trained ViT backbone.
        self.model = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224'
        ).to(self.device)

        # Replace the classification head with a 2-output layer before the
        # state dict is loaded (presumably so the layer shapes match the
        # fine-tuned checkpoint -- confirm against the training script).
        self.model.classifier = nn.Linear(
            self.model.config.hidden_size, 2
        ).to(self.device)

        # weights_only=True restricts unpickling to tensor data.
        state_dict = torch.load(
            weights_path, map_location=self.device, weights_only=True
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # Resize to 224x224 and convert to a float tensor.
        # NOTE(review): no mean/std normalization is applied here; this has
        # to match the preprocessing used when the weights were trained --
        # confirm.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def predict(self, image: Image.Image):
        """Classify a single image.

        Args:
            image: PIL image in any mode; non-RGB images are converted to
                RGB before preprocessing.

        Returns:
            A ``(predicted_label, confidence)`` tuple, where
            ``predicted_label`` is the index of the most probable class and
            ``confidence`` is its softmax probability multiplied by 100.
        """
        # Grayscale, palette or RGBA images would otherwise produce a tensor
        # with 1 or 4 channels, while the pre-trained ViT is expected to
        # take 3-channel input.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess and add the batch dimension.
        batch = self.preprocess(image).unsqueeze(0).to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = self.model(batch).logits
            probabilities = F.softmax(logits, dim=1)
            confidences, predicted = torch.max(probabilities, 1)

        predicted_label = predicted.item()
        confidence = confidences.item() * 100  # Convert to percentage format
        return predicted_label, confidence