import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image


class CustomModel:
    def __init__(self):
        # Explicitly set the device to CPU
        self.device = torch.device('cpu')
        # Load the pre-trained ViT and replace its classifier head with a 2-class layer
        self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(self.device)
        self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
        # Load the fine-tuned weights
        self.model.load_state_dict(torch.load('trained_model.pth', map_location=self.device, weights_only=True))
        self.model.eval()
        # Resize the image to 224x224 and convert it to a tensor
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def predict(self, image: Image.Image):
        # Preprocess the image and add a batch dimension
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        # Perform inference
        with torch.no_grad():
            outputs = self.model(image)
            logits = outputs.logits
            probabilities = F.softmax(logits, dim=1)
            confidences, predicted = torch.max(probabilities, 1)
            predicted_label = predicted.item()
            confidence = confidences.item() * 100  # Convert to a percentage
        return predicted_label, confidence
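

# Minimal usage sketch (not part of the original script): it assumes
# 'trained_model.pth' is present next to this file, as loaded above, and
# uses a hypothetical test image path 'example.jpg'. The .convert('RGB')
# call guards against grayscale or RGBA inputs before the ViT preprocessing.
if __name__ == '__main__':
    model = CustomModel()
    img = Image.open('example.jpg').convert('RGB')
    label, confidence = model.predict(img)
    print(f'Predicted class: {label} ({confidence:.2f}% confidence)')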