File size: 1,222 Bytes
5cdf407 85dd38b 5cdf407 85dd38b 5cdf407 85dd38b 5cdf407 85dd38b 5cdf407 85dd38b 5cdf407 85dd38b 5cdf407 85dd38b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
from torchvision import transforms
from pathlib import Path
import json
import os
import sys
from model import MNISTModel
from inference_util import Inferencer
class InferenceWrapper:
    """Thin wrapper around Inferencer that forces a caller-specified checkpoint.

    Constructs the underlying Inferencer with the given input/output
    directories, then replaces its model with one loaded from *model_path*.
    """
    def __init__(self, model_path: str, input_dir: str = 'input_data', output_dir: str = 'output_data'):
        self.model_path = model_path
        self.inferencer = Inferencer(input_dir, output_dir)
        # Swap in the model loaded from the caller-supplied path, discarding
        # the second value _load_model returns.
        # NOTE(review): this reaches into the private `_load_model` helper of
        # Inferencer — confirm that API is stable before relying on it.
        loaded_model, _ = self.inferencer._load_model(model_path)
        self.inferencer.model = loaded_model
    def run_inference(self):
        """Run inference over the input directory via the wrapped Inferencer."""
        return self.inferencer.process_input()
def main():
    """Command-line entry point.

    Parses --model-path (required), --input-dir and --output-dir, runs
    inference through InferenceWrapper, and reports how many inputs were
    processed.
    """
    import argparse
    cli = argparse.ArgumentParser()
    cli.add_argument('--model-path', required=True, help='Path to the model weights')
    cli.add_argument('--input-dir', default='input_data')
    cli.add_argument('--output-dir', default='output_data')
    opts = cli.parse_args()
    # Delegate the actual work to the wrapper; it owns model loading.
    runner = InferenceWrapper(opts.model_path, opts.input_dir, opts.output_dir)
    outputs = runner.run_inference()
    print(f"Processed {len(outputs)} inputs using model: {opts.model_path}")
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|