|
import torch |
|
from torchvision import transforms |
|
from model import MNISTModel |
|
|
|
|
|
class InferenceWrapper:
    """Load ``MNISTModel`` weights once and expose prediction helpers.

    ``predict_tensor`` classifies exactly one sample and returns Python
    scalars; ``predict_batch`` classifies N samples and returns NumPy arrays.
    Neither method applies ``self.transform``: callers are expected to pass
    tensors that are already preprocessed.
    """

    def __init__(self, model_path: str):
        """
        Initialize the inference wrapper with a model path.

        Args:
            model_path (str): Path to the model weights file
        """
        # Prefer the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.model = self._load_model()
        # Preprocessing pipeline offered to callers for raw images.
        # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std
        # used at training time -- confirm against the training code.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def _load_model(self):
        """Load and return the model.

        Builds a fresh ``MNISTModel`` on ``self.device``, loads the state dict
        stored at ``self.model_path`` and switches the model to eval mode.
        """
        model = MNISTModel().to(self.device)
        # map_location places the stored tensors on the selected device, so a
        # checkpoint saved on GPU still loads on a CPU-only host.
        # weights_only=True restricts what torch.load will unpickle.
        state_dict = torch.load(
            self.model_path, map_location=self.device, weights_only=True
        )
        model.load_state_dict(state_dict)
        # Evaluation mode (matters for layers such as dropout / batch-norm,
        # if the model has any).
        model.eval()
        return model

    def _classify(self, batch: torch.Tensor):
        """Return per-sample ``(class index, softmax probability)`` tensors.

        Shared by ``predict_tensor`` and ``predict_batch`` so both paths
        compute predictions and confidences in exactly the same way.
        """
        with torch.no_grad():
            logits = self.model(batch.to(self.device))
            probs = torch.softmax(logits, dim=1)
            predictions = logits.argmax(1)
            # Pick, for every row, the probability of the predicted class.
            confidences = torch.gather(probs, 1, predictions.unsqueeze(1)).squeeze(1)
        return predictions, confidences

    def predict_tensor(self, input_tensor: torch.Tensor):
        """
        Run inference on a single input tensor.

        Args:
            input_tensor (torch.Tensor): One sample, of shape [1, 28, 28]
                (a batch dimension is added) or [1, 1, 28, 28].

        Returns:
            tuple: (prediction, confidence) -- the predicted class index as an
            int and its softmax probability as a float.

        Raises:
            ValueError: If the tensor holds zero or more than one sample.
                Use ``predict_batch`` for multi-sample input.
        """
        if input_tensor.dim() == 3:
            # A lone sample without a batch axis: add one.
            input_tensor = input_tensor.unsqueeze(0)

        # Fail early with a clear message instead of the opaque
        # tensor-to-scalar conversion error that ``.item()`` would raise.
        if input_tensor.dim() >= 1 and input_tensor.size(0) != 1:
            raise ValueError(
                "predict_tensor expects exactly one sample, got input of shape "
                f"{tuple(input_tensor.shape)}; use predict_batch for multiple samples"
            )

        predictions, confidences = self._classify(input_tensor)
        return predictions[0].item(), confidences[0].item()

    def predict_batch(self, input_tensors: torch.Tensor):
        """
        Run inference on a batch of input tensors.

        Args:
            input_tensors (torch.Tensor): Batch of input tensors of shape [N, 1, 28, 28]

        Returns:
            tuple: (predictions, confidences) -- two NumPy arrays of length N
            holding the predicted class indices and their softmax
            probabilities.
        """
        predictions, confidences = self._classify(input_tensors)
        # Move to host memory before converting; .numpy() needs CPU tensors.
        return predictions.cpu().numpy(), confidences.cpu().numpy()
|
|
|
|
|
def main():
    """Command-line smoke test for ``InferenceWrapper``.

    Parses ``--model-path``, loads the model, then runs one random
    single-sample prediction and one random four-sample batch prediction,
    printing the results.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--model-path', required=True, help='Path to the model weights')
    options = cli.parse_args()

    engine = InferenceWrapper(options.model_path)

    # Single-sample path: random noise shaped like one image, no batch axis.
    single_sample = torch.randn(1, 28, 28)
    prediction, confidence = engine.predict_tensor(single_sample)
    print(f"Single prediction: {prediction}, confidence: {confidence:.4f}")

    # Batch path: four random samples with an explicit batch axis.
    sample_batch = torch.randn(4, 1, 28, 28)
    predictions, confidences = engine.predict_batch(sample_batch)
    print(f"Batch predictions: {predictions}")
    print(f"Batch confidences: {confidences}")
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()