Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from transformers import ViTForImageClassification | |
| from PIL import Image | |
class CustomModel:
    """Two-class image classifier built on a fine-tuned ViT, run on CPU.

    The backbone is ``google/vit-base-patch16-224`` with its classifier
    replaced by a 2-output linear layer; the fine-tuned parameters are then
    restored from a local checkpoint file.
    """

    def __init__(self, weights_path: str = 'trained_model.pth'):
        """Build the model, load the checkpoint, and set up preprocessing.

        Args:
            weights_path: Path to the state-dict checkpoint to load. Defaults
                to ``'trained_model.pth'`` (resolved against the current
                working directory).
        """
        # Explicitly set the device to CPU
        self.device = torch.device('cpu')

        # Load the pre-trained ViT, then swap in a 2-output head so the
        # classifier layer has the shape stored in the checkpoint.
        self.model = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224'
        ).to(self.device)
        self.model.classifier = nn.Linear(
            self.model.config.hidden_size, 2
        ).to(self.device)

        # Load model weights. weights_only=True keeps torch.load from
        # unpickling arbitrary Python objects out of the checkpoint file.
        state_dict = torch.load(
            weights_path, map_location=self.device, weights_only=True
        )
        self.model.load_state_dict(state_dict)
        # Inference mode for layers such as dropout.
        self.model.eval()

        # Resize the image to 224x224 and convert it to a tensor. No mean/std
        # normalisation is applied here.
        # NOTE(review): presumably this matches the preprocessing used when
        # the checkpoint was trained -- confirm against the training script.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def predict(self, image: Image.Image):
        """Classify a single image.

        Args:
            image: PIL image in any colour mode; it is converted to RGB
                before preprocessing.

        Returns:
            A ``(predicted_label, confidence)`` tuple: the index of the
            highest-probability class, and that class's softmax probability
            expressed as a percentage.
        """
        # RGBA, grayscale or palette images would otherwise yield a tensor
        # with 4 or 1 channels, which the 3-channel ViT cannot accept.
        rgb_image = image.convert('RGB')

        # Preprocess the image and add the leading batch dimension.
        batch = self.preprocess(rgb_image).unsqueeze(0).to(self.device)

        # Perform inference without tracking gradients.
        with torch.no_grad():
            logits = self.model(batch).logits
            probabilities = F.softmax(logits, dim=1)
            confidences, predicted = torch.max(probabilities, 1)

        predicted_label = predicted.item()
        confidence = confidences.item() * 100  # Convert to percentage format
        return predicted_label, confidence