# File size: 3,131 Bytes
# 5cdf407 85dd38b 5cdf407 85dd38b 2dea402 85dd38b 2dea402 85dd38b 5cdf407 85dd38b 5cdf407 2dea402 5cdf407 2dea402 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import torch
from torchvision import transforms
from model import MNISTModel
class InferenceWrapper:
    """Load ``MNISTModel`` weights from disk and run inference on tensors.

    The model lives on CUDA when ``torch.cuda.is_available()`` is true and on
    the CPU otherwise; it is switched to eval mode once, at construction.
    """

    def __init__(self, model_path: str):
        """
        Initialize the inference wrapper with a model path.

        Args:
            model_path (str): Path to the model weights file (a state dict that
                ``torch.load(..., weights_only=True)`` can read).
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.model = self._load_model()
        # ToTensor followed by Normalize(mean, std); the constants are
        # presumably the MNIST training-set mean/std — confirm against training.
        # NOTE(review): no method in this class applies self.transform; callers
        # are expected to preprocess raw images with it before predict_*.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def _load_model(self):
        """Build an ``MNISTModel``, load ``self.model_path`` into it, return it in eval mode."""
        model = MNISTModel().to(self.device)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run arbitrary code.
        state_dict = torch.load(self.model_path, map_location=self.device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def _run_model(self, batch: torch.Tensor):
        """Forward an already-batched tensor; return (predictions, confidences) tensors.

        ``predictions`` holds the arg-max class index per sample and
        ``confidences`` the softmax probability of that class, both of shape [N].
        """
        with torch.no_grad():
            logits = self.model(batch.to(self.device))
            probs = torch.softmax(logits, dim=1)
            predictions = logits.argmax(dim=1)
            # Pick, for every row, the probability at its predicted class index.
            confidences = probs.gather(1, predictions.unsqueeze(1)).squeeze(1)
        return predictions, confidences

    def predict_tensor(self, input_tensor: torch.Tensor):
        """
        Run inference on a single input sample.

        Args:
            input_tensor (torch.Tensor): One sample of shape [1, 28, 28], or the
                same sample already batched as [1, 1, 28, 28]. Use
                ``predict_batch`` for more than one sample.

        Returns:
            tuple: (prediction, confidence) — the predicted class index as a
            Python ``int`` and its softmax probability as a Python ``float``.

        Raises:
            ValueError: If the input holds more than one sample.
        """
        # A 3-D tensor is treated as one un-batched sample: add the batch axis.
        if input_tensor.dim() == 3:
            input_tensor = input_tensor.unsqueeze(0)
        # Previously a batch of N > 1 failed deep inside .item(); reject it
        # explicitly instead of with an opaque tensor-conversion error.
        if input_tensor.shape[0] != 1:
            raise ValueError(
                "predict_tensor expects a single sample, got input of shape "
                f"{tuple(input_tensor.shape)}; use predict_batch for batches"
            )
        predictions, confidences = self._run_model(input_tensor)
        return predictions.item(), confidences.item()

    def predict_batch(self, input_tensors: torch.Tensor):
        """
        Run inference on a batch of input tensors.

        Args:
            input_tensors (torch.Tensor): Batch of input tensors of shape [N, 1, 28, 28]

        Returns:
            tuple: (predictions, confidences) — NumPy arrays of shape [N] with the
            predicted class index per sample and its softmax probability.
        """
        predictions, confidences = self._run_model(input_tensors)
        return predictions.cpu().numpy(), confidences.cpu().numpy()
def main(argv=None):
    """Command-line demo: load weights, then run single and batch inference on random inputs.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, exactly as before; passing a
            list allows calling ``main`` programmatically (e.g. from tests).
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', required=True, help='Path to the model weights')
    args = parser.parse_args(argv)

    # Example usage
    wrapper = InferenceWrapper(args.model_path)

    # Example single inference: random noise shaped like one [1, 28, 28] sample.
    test_input = torch.randn(1, 28, 28)
    prediction, confidence = wrapper.predict_tensor(test_input)
    print(f"Single prediction: {prediction}, confidence: {confidence:.4f}")

    # Example batch inference: four random samples shaped [4, 1, 28, 28].
    batch_input = torch.randn(4, 1, 28, 28)
    predictions, confidences = wrapper.predict_batch(batch_input)
    print(f"Batch predictions: {predictions}")
    print(f"Batch confidences: {confidences}")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    main()