# NOTE: the lines below replace web-scrape residue (Hugging Face Spaces page
# header, file size, commit hashes, and a line-number gutter) that was fused
# into the file during extraction and was not valid Python.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image
class CustomModel:
    """ViT-based binary image classifier that runs on CPU.

    Loads the pre-trained ``google/vit-base-patch16-224`` backbone, replaces
    its classification head with a 2-class linear layer, and restores
    fine-tuned weights from ``trained_model.pth``.
    """

    def __init__(self):
        # Explicitly pin inference to CPU — no GPU is assumed at deploy time.
        self.device = torch.device('cpu')
        # Pre-trained ViT backbone with the classifier head swapped out
        # for a fresh 2-class linear layer.
        self.model = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224'
        ).to(self.device)
        self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
        # Restore fine-tuned weights; weights_only=True prevents arbitrary
        # code execution from an untrusted pickle payload.
        self.model.load_state_dict(
            torch.load('trained_model.pth', map_location=self.device, weights_only=True)
        )
        self.model.eval()
        # Resize to the ViT input resolution and convert to a tensor.
        # NOTE(review): no ImageNet mean/std Normalize step is applied here;
        # this must match how 'trained_model.pth' was trained — confirm.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def predict(self, image: Image.Image):
        """Classify a PIL image.

        Args:
            image: input PIL image (any mode; converted to RGB internally).

        Returns:
            tuple[int, float]: the argmax class index (0 or 1) and its
            softmax probability expressed as a percentage.
        """
        # Force 3-channel RGB so ToTensor always yields a (3, 224, 224)
        # tensor even for grayscale or RGBA inputs, then add a batch dim.
        image = self.preprocess(image.convert('RGB')).unsqueeze(0).to(self.device)
        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(image)
            logits = outputs.logits
            probabilities = F.softmax(logits, dim=1)
            confidences, predicted = torch.max(probabilities, 1)
            predicted_label = predicted.item()
            confidence = confidences.item() * 100  # convert to percentage
        return predicted_label, confidence