import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image


class CustomModel:
    """Image classifier built on a fine-tuned Vision Transformer (ViT).

    Loads a pretrained ViT backbone from Hugging Face and replaces its
    classification head with an ``nn.Linear`` producing ``num_classes``
    logits. It then loads fine-tuned weights from ``weights_path`` and puts
    the model in eval mode.
    """

    def __init__(
        self,
        *,
        weights_path: str = 'trained_model.pth',
        model_name: str = 'google/vit-base-patch16-224',
        num_classes: int = 2,
        device: str = 'cpu',
    ):
        """Build the model and load the fine-tuned weights.

        All arguments default to the values this class originally hard-coded,
        so ``CustomModel()`` behaves as before.

        Args:
            weights_path: Path to the fine-tuned ``state_dict`` checkpoint.
            model_name: Hugging Face identifier of the pretrained ViT backbone.
            num_classes: Number of outputs of the replacement classifier head.
                Must match the head shape stored in the checkpoint.
            device: Torch device string the model and inputs are placed on.
        """
        self.device = torch.device(device)

        # Load the pretrained ViT backbone.
        self.model = ViTForImageClassification.from_pretrained(model_name).to(self.device)
        # Swap in a head whose shape matches the fine-tuned checkpoint.
        # load_state_dict below is strict by default, so the shapes must agree.
        self.model.classifier = nn.Linear(self.model.config.hidden_size, num_classes).to(self.device)

        # weights_only=True restricts unpickling to tensors and plain
        # containers, so loading a checkpoint cannot run arbitrary code.
        state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # Resize to the 224x224 input size and convert to a float tensor.
        # NOTE(review): no Normalize step is applied here. Confirm this matches
        # the preprocessing used when the checkpoint was trained.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def predict(self, image: Image.Image):
        """Classify a single PIL image.

        Args:
            image: Input image. Images in any mode other than ``RGB``
                (e.g. ``L``, ``P``, ``RGBA``) are converted to RGB first.

        Returns:
            A ``(predicted_label, confidence)`` tuple:

            - ``predicted_label`` is the index (int) of the highest-probability
              class.
            - ``confidence`` is that class's softmax probability expressed as a
              percentage (float in ``[0, 100]``).
        """
        # ToTensor keeps the PIL channel count (L -> 1, RGBA -> 4), but the ViT
        # patch embedding expects 3 channels, so convert to RGB up front.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess, then add a leading batch dimension of size 1.
        batch = self.preprocess(image).unsqueeze(0).to(self.device)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(pixel_values=batch).logits

        probabilities = F.softmax(logits, dim=1)
        confidences, predicted = torch.max(probabilities, dim=1)
        predicted_label = predicted.item()
        confidence = confidences.item() * 100  # Convert to percentage format
        return predicted_label, confidence