import argparse

import torch
from torchvision import transforms

from model import MNISTModel


class InferenceWrapper:
    def __init__(self, model_path: str):
        """
        Initialize the inference wrapper with a model path.

        Args:
            model_path (str): Path to the model weights file
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.model = self._load_model()
        # Standard MNIST preprocessing: convert to tensor and normalize with the
        # dataset mean/std. Intended for raw PIL images; the predict_* methods
        # below expect tensors that are already preprocessed.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def _load_model(self):
        """Load the model weights and return the model in eval mode."""
        model = MNISTModel().to(self.device)
        model.load_state_dict(
            torch.load(self.model_path, map_location=self.device, weights_only=True)
        )
        model.eval()
        return model

    def predict_tensor(self, input_tensor: torch.Tensor):
        """
        Run inference on a single input tensor.

        Args:
            input_tensor (torch.Tensor): Input tensor of shape [1, 28, 28]
                or [1, 1, 28, 28]

        Returns:
            tuple: (prediction, confidence)
        """
        with torch.no_grad():
            if input_tensor.dim() == 3:
                input_tensor = input_tensor.unsqueeze(0)  # add batch dimension
            input_tensor = input_tensor.to(self.device)
            output = self.model(input_tensor)
            probs = torch.softmax(output, dim=1)
            prediction = output.argmax(1).item()
            confidence = probs[0][prediction].item()
            return prediction, confidence

    def predict_batch(self, input_tensors: torch.Tensor):
        """
        Run inference on a batch of input tensors.

        Args:
            input_tensors (torch.Tensor): Batch of input tensors of shape
                [N, 1, 28, 28]

        Returns:
            tuple: (predictions, confidences) as NumPy arrays of length N
        """
        with torch.no_grad():
            input_tensors = input_tensors.to(self.device)
            output = self.model(input_tensors)
            probs = torch.softmax(output, dim=1)
            predictions = output.argmax(1)
            # Pick each sample's probability at its predicted class.
            confidences = torch.gather(probs, 1, predictions.unsqueeze(1)).squeeze(1)
            return predictions.cpu().numpy(), confidences.cpu().numpy()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', required=True, help='Path to the model weights')
    args = parser.parse_args()

    # Example usage
    wrapper = InferenceWrapper(args.model_path)

    # Example single inference on random data
    test_input = torch.randn(1, 28, 28)
    prediction, confidence = wrapper.predict_tensor(test_input)
    print(f"Single prediction: {prediction}, confidence: {confidence:.4f}")

    # Example batch inference on random data
    batch_input = torch.randn(4, 1, 28, 28)
    predictions, confidences = wrapper.predict_batch(batch_input)
    print(f"Batch predictions: {predictions}")
    print(f"Batch confidences: {confidences}")


if __name__ == "__main__":
    main()
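
# --- Usage sketch (illustrative, not part of the module's entry point) ---
# The wrapper stores `self.transform` but the predict_* methods above accept
# tensors that are already normalized. One plausible way to run inference on a
# raw image file is sketched below; "digit.png" and "mnist_weights.pt" are
# hypothetical paths, and the image is assumed to be a 28x28 grayscale digit.
#
#   from PIL import Image
#
#   wrapper = InferenceWrapper("mnist_weights.pt")          # hypothetical weights file
#   image = Image.open("digit.png").convert("L")            # ensure single channel
#   tensor = wrapper.transform(image)                       # -> [1, 28, 28], normalized
#   prediction, confidence = wrapper.predict_tensor(tensor)
#   print(prediction, confidence)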