benjaminStreltzin committed
Commit 22aae95 · verified · 1 Parent(s): 2dc8dc5

Update vit_model_test.py

Files changed (1)
  1. vit_model_test.py  +4 -4
vit_model_test.py CHANGED
@@ -10,7 +10,7 @@ class CustomModel:
         # Explicitly set the device to CPU
         self.device = torch.device('cpu')
 
-        # Load the pre-trained ViT model and move it to CPU
+        # Load the pre-trained ViT
         self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(self.device)
         self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)
 
@@ -18,7 +18,7 @@ class CustomModel:
         self.model.load_state_dict(torch.load('trained_model.pth', map_location=self.device, weights_only=True))
         self.model.eval()
 
-        # Define the image preprocessing pipeline
+        # Resize the image and make it a tensor (add dimension)
         self.preprocess = transforms.Compose([
             transforms.Resize((224, 224)),
             transforms.ToTensor()
@@ -26,7 +26,7 @@ class CustomModel:
 
     def predict(self, image: Image.Image):
         # Preprocess the image
-        image = self.preprocess(image).unsqueeze(0).to(self.device)  # Add batch dimension
+        image = self.preprocess(image).unsqueeze(0).to(self.device)
 
         # Perform inference
         with torch.no_grad():
@@ -35,6 +35,6 @@ class CustomModel:
         probabilities = F.softmax(logits, dim=1)
         confidences, predicted = torch.max(probabilities, 1)
         predicted_label = predicted.item()
-        confidence = confidences.item() * 100  # Convert to percentage
+        confidence = confidences.item() * 100  # Convert to percentage format
 
         return predicted_label, confidence
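
For reference, a minimal sketch of how the CustomModel class touched by this commit could be exercised. The module import path, the no-argument constructor, and the image path are assumptions inferred from the hunks above, not confirmed by the rest of the repository.

# Minimal usage sketch (assumptions: vit_model_test.py is importable as a module,
# CustomModel takes no constructor arguments, and the full file already imports
# torch, torch.nn as nn, torch.nn.functional as F, torchvision.transforms as
# transforms, PIL.Image, and transformers.ViTForImageClassification, as the
# hunks imply; 'example.jpg' is a placeholder path).
from PIL import Image

from vit_model_test import CustomModel

model = CustomModel()                              # loads the ViT weights on CPU and sets eval mode
image = Image.open('example.jpg').convert('RGB')   # ensure 3 channels before Resize/ToTensor
label, confidence = model.predict(image)           # (predicted class index, confidence in percent)
print(f"class {label} at {confidence:.1f}% confidence")

Because the model is pinned to torch.device('cpu') and inference runs under torch.no_grad(), this call should work in environments without a GPU.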