mnist / inference.py
tezuesh's picture
Upload folder using huggingface_hub
85dd38b verified
raw
history blame
1.22 kB
import torch
from torchvision import transforms
from pathlib import Path
import json
import os
import sys
from model import MNISTModel
from inference_util import Inferencer
class InferenceWrapper:
    """Wrap an ``Inferencer`` so it runs with a caller-specified model.

    An ``Inferencer`` is built for the given input/output directories, and
    its ``model`` attribute is then replaced with the model loaded from
    ``model_path``.
    """

    def __init__(self, model_path: str, input_dir: str = 'input_data', output_dir: str = 'output_data'):
        """Build the wrapped ``Inferencer`` and swap in the requested model.

        Args:
            model_path: Location handed to ``Inferencer._load_model``.
            input_dir: First argument passed to ``Inferencer`` (the input
                directory, judging by the name).
            output_dir: Second argument passed to ``Inferencer`` (the output
                directory, judging by the name).
        """
        self.model_path = model_path
        engine = Inferencer(input_dir, output_dir)
        self.inferencer = engine
        # ``_load_model`` is expected to return exactly two items; only the
        # first (the model) is kept. What the second item holds is not
        # visible from this file.
        loaded_model, _discarded = engine._load_model(model_path)
        engine.model = loaded_model

    def run_inference(self):
        """Process the inputs with the overridden model.

        Returns:
            Whatever ``Inferencer.process_input()`` returns. This is presumably
            a sized collection of per-input results; confirm against
            ``Inferencer``.
        """
        outcome = self.inferencer.process_input()
        return outcome
def main(argv=None):
    """Command-line entry point: run inference with a chosen model file.

    Args:
        argv: Argument list to parse instead of the real command line.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, which
            matches the previous behaviour. Passing a list lets other code
            and tests drive the CLI without touching ``sys.argv``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Run inference over an input directory with the specified model weights.'
    )
    parser.add_argument('--model-path', required=True, help='Path to the model weights')
    parser.add_argument(
        '--input-dir',
        default='input_data',
        help='Input directory passed to the inferencer (default: %(default)s)',
    )
    parser.add_argument(
        '--output-dir',
        default='output_data',
        help='Output directory passed to the inferencer (default: %(default)s)',
    )
    # parse_args(None) falls back to sys.argv[1:], so existing callers of
    # main() see no change.
    args = parser.parse_args(argv)

    wrapper = InferenceWrapper(args.model_path, args.input_dir, args.output_dir)
    results = wrapper.run_inference()
    print(f"Processed {len(results)} inputs using model: {args.model_path}")
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()