# Source: mnist/inference_util.py
# Author: tezuesh
# Commit: 85dd38b (verified) - "Upload folder using huggingface_hub"
# inference.py
import torch
from torchvision import transforms, datasets
from PIL import Image
import json
from pathlib import Path
from model import MNISTModel
import os
import sys
class Inferencer:
    """Run a trained ``MNISTModel`` over serialized tensors in a directory.

    Each ``*.pt`` file in ``input_dir`` is loaded, classified, and the
    per-file prediction and softmax confidence are written to
    ``output_dir / 'results.json'``.
    """

    def __init__(self, input_dir: str = 'input_data', output_dir: str = 'output_data'):
        """Load the model and remember the input/output directories.

        Args:
            input_dir: Directory scanned for ``*.pt`` tensor files.
            output_dir: Directory that receives ``results.json``; it is
                created on demand by ``process_input``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # _load_model also returns the device; only the model is kept.
        self.model, _ = self._load_model()
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        # Mean/std normalization (presumably the MNIST training-set statistics).
        # NOTE(review): this transform is not used by process_input, which only
        # reads pre-built tensors; presumably intended for PIL image inputs.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def _load_model(self, model_path='best_model.pth'):
        """Build an ``MNISTModel`` on ``self.device`` and load trained weights.

        Args:
            model_path: Path to a saved ``state_dict``; relative paths are
                resolved against the current working directory.

        Returns:
            A ``(model, device)`` tuple; the model is switched to eval mode.
        """
        model = MNISTModel().to(self.device)
        # weights_only=True restricts unpickling to tensors and plain containers.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        return model, self.device

    def predict(self, input_tensor: torch.Tensor):
        """Classify a single input and return ``(prediction, confidence)``.

        A 3-D input is treated as one unbatched sample and given a leading
        batch dimension. Because ``.item()`` is applied to the argmax and only
        row 0 of the probabilities is read, the input must hold exactly one
        sample.

        Returns:
            prediction: ``int`` index of the largest model output along dim 1.
            confidence: ``float`` softmax probability of ``prediction``.
        """
        with torch.no_grad():
            if input_tensor.dim() == 3:
                input_tensor = input_tensor.unsqueeze(0)
            input_tensor = input_tensor.to(self.device)
            output = self.model(input_tensor)
            probs = torch.softmax(output, dim=1)
            prediction = output.argmax(1).item()
            confidence = probs[0][prediction].item()
        return prediction, confidence

    def process_input(self):
        """Classify every ``*.pt`` file in ``input_dir`` and save the results.

        Files are processed in sorted path order. A file that fails to load or
        predict is reported on stderr and skipped, so one bad input does not
        abort the whole batch.

        Returns:
            A list of ``{"filename", "prediction", "confidence"}`` dicts, which
            is also written to ``output_dir / 'results.json'``.
        """
        self.output_dir.mkdir(parents=True, exist_ok=True)
        if not self.input_dir.is_dir():
            # glob() on a missing directory yields nothing; make that visible.
            print(f"Warning: input directory {self.input_dir} not found; "
                  "no inputs will be processed", file=sys.stderr)

        results = []
        for file_path in sorted(self.input_dir.glob('*.pt')):
            try:
                # Input files are untrusted: weights_only=True refuses arbitrary
                # pickled objects (which could execute code on load), and
                # map_location lets CUDA-saved tensors load on CPU-only hosts.
                input_tensor = torch.load(file_path, map_location=self.device, weights_only=True)
                prediction, confidence = self.predict(input_tensor)
            except Exception as e:
                # Deliberate best-effort boundary: report and move on.
                print(f"Error processing {file_path}: {str(e)}", file=sys.stderr)
                continue
            results.append({
                "filename": file_path.name,
                "prediction": prediction,
                "confidence": confidence,
            })

        with open(self.output_dir / 'results.json', 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        return results
def main(argv=None):
    """Command-line entry point: run inference over a directory of tensors.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behavior.
    """
    # Accept input/output directories as arguments
    import argparse

    parser = argparse.ArgumentParser(
        description="Run MNIST model inference on *.pt tensor files.")
    parser.add_argument('--input-dir', default='input_data',
                        help="directory containing *.pt input tensors "
                             "(default: %(default)s)")
    parser.add_argument('--output-dir', default='output_data',
                        help="directory where results.json is written "
                             "(default: %(default)s)")
    args = parser.parse_args(argv)
    inferencer = Inferencer(args.input_dir, args.output_dir)
    results = inferencer.process_input()
    print(f"Processed {len(results)} inputs")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()