Spaces:
Sleeping
Sleeping
Update vit_model_test.py
Browse files- vit_model_test.py +4 -4
vit_model_test.py
CHANGED
@@ -10,7 +10,7 @@ class CustomModel:
|
|
10 |
# Explicitly set the device to CPU
|
11 |
self.device = torch.device('cpu')
|
12 |
|
13 |
-
# Load the pre-trained ViT
|
14 |
self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(self.device)
|
15 |
self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
|
16 |
|
@@ -18,7 +18,7 @@ class CustomModel:
|
|
18 |
self.model.load_state_dict(torch.load('trained_model.pth', map_location=self.device, weights_only=True))
|
19 |
self.model.eval()
|
20 |
|
21 |
-
#
|
22 |
self.preprocess = transforms.Compose([
|
23 |
transforms.Resize((224, 224)),
|
24 |
transforms.ToTensor()
|
@@ -26,7 +26,7 @@ class CustomModel:
|
|
26 |
|
27 |
def predict(self, image: Image.Image):
|
28 |
# Preprocess the image
|
29 |
-
image = self.preprocess(image).unsqueeze(0).to(self.device)
|
30 |
|
31 |
# Perform inference
|
32 |
with torch.no_grad():
|
@@ -35,6 +35,6 @@ class CustomModel:
|
|
35 |
probabilities = F.softmax(logits, dim=1)
|
36 |
confidences, predicted = torch.max(probabilities, 1)
|
37 |
predicted_label = predicted.item()
|
38 |
-
confidence = confidences.item() * 100 # Convert to percentage
|
39 |
|
40 |
return predicted_label, confidence
|
|
|
10 |
# Explicitly set the device to CPU
|
11 |
self.device = torch.device('cpu')
|
12 |
|
13 |
+
# Load the pre-trained ViT
|
14 |
self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(self.device)
|
15 |
self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
|
16 |
|
|
|
18 |
self.model.load_state_dict(torch.load('trained_model.pth', map_location=self.device, weights_only=True))
|
19 |
self.model.eval()
|
20 |
|
21 |
+
# Resize the image and make it a tensor (add dimension)
|
22 |
self.preprocess = transforms.Compose([
|
23 |
transforms.Resize((224, 224)),
|
24 |
transforms.ToTensor()
|
|
|
26 |
|
27 |
def predict(self, image: Image.Image):
|
28 |
# Preprocess the image
|
29 |
+
image = self.preprocess(image).unsqueeze(0).to(self.device)
|
30 |
|
31 |
# Perform inference
|
32 |
with torch.no_grad():
|
|
|
35 |
probabilities = F.softmax(logits, dim=1)
|
36 |
confidences, predicted = torch.max(probabilities, 1)
|
37 |
predicted_label = predicted.item()
|
38 |
+
confidence = confidences.item() * 100 # Convert to percentage format
|
39 |
|
40 |
return predicted_label, confidence
|