# mnist/inference.py
import torch
from torchvision import transforms
from model import MNISTModel


class InferenceWrapper:
    def __init__(self, model_path: str):
        """
        Initialize the inference wrapper with a model path.

        Args:
            model_path (str): Path to the model weights file
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.model = self._load_model()
        # Standard MNIST preprocessing: convert to tensor, then normalize
        # with the dataset's mean (0.1307) and std (0.3081).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def _load_model(self):
        """Load and return the model."""
        model = MNISTModel().to(self.device)
        model.load_state_dict(
            torch.load(self.model_path, map_location=self.device, weights_only=True)
        )
        model.eval()
        return model

    def predict_tensor(self, input_tensor: torch.Tensor):
        """
        Run inference on a single input tensor.

        Args:
            input_tensor (torch.Tensor): Input tensor of shape [1, 28, 28] or [1, 1, 28, 28].
                For batched inputs, use predict_batch instead.

        Returns:
            tuple: (prediction, confidence) for the single input
        """
        with torch.no_grad():
            if input_tensor.dim() == 3:
                input_tensor = input_tensor.unsqueeze(0)
            input_tensor = input_tensor.to(self.device)
            output = self.model(input_tensor)
            probs = torch.softmax(output, dim=1)
            prediction = output.argmax(1).item()
            confidence = probs[0][prediction].item()
            return prediction, confidence

    def predict_batch(self, input_tensors: torch.Tensor):
        """
        Run inference on a batch of input tensors.

        Args:
            input_tensors (torch.Tensor): Batch of input tensors of shape [N, 1, 28, 28]

        Returns:
            tuple: (predictions, confidences) as NumPy arrays of length N
        """
        with torch.no_grad():
            input_tensors = input_tensors.to(self.device)
            output = self.model(input_tensors)
            probs = torch.softmax(output, dim=1)
            predictions = output.argmax(1)
            # Pick each row's probability at its predicted class index.
            confidences = torch.gather(probs, 1, predictions.unsqueeze(1)).squeeze(1)
            return predictions.cpu().numpy(), confidences.cpu().numpy()
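
    # --- Illustrative addition (not part of the original upload) ---
    # The wrapper defines self.transform but never applies it. Below is a minimal
    # sketch of how it could be used on a raw image file, assuming Pillow is
    # installed and the file contains a grayscale digit:
    def predict_image(self, image_path: str):
        """Sketch: preprocess an image file with self.transform and classify it."""
        from PIL import Image  # local import; Pillow assumed available
        image = Image.open(image_path).convert('L').resize((28, 28))
        tensor = self.transform(image)  # -> [1, 28, 28], normalized
        return self.predict_tensor(tensor)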


def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', required=True, help='Path to the model weights')
    args = parser.parse_args()

    # Example usage
    wrapper = InferenceWrapper(args.model_path)

    # Example single inference
    test_input = torch.randn(1, 28, 28)
    prediction, confidence = wrapper.predict_tensor(test_input)
    print(f"Single prediction: {prediction}, confidence: {confidence:.4f}")

    # Example batch inference
    batch_input = torch.randn(4, 1, 28, 28)
    predictions, confidences = wrapper.predict_batch(batch_input)
    print(f"Batch predictions: {predictions}")
    print(f"Batch confidences: {confidences}")


if __name__ == "__main__":
    main()
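
# Example invocation (the weights filename below is only a placeholder):
#   python inference.py --model-path mnist_weights.pth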