|
import torch |
|
from torchvision import transforms |
|
from pathlib import Path |
|
import json |
|
import os |
|
import sys |
|
from model import MNISTModel |
|
from inference_util import Inferencer |
|
|
|
|
|
class InferenceWrapper:
    """Convenience facade that binds an ``Inferencer`` to one set of model weights.

    Construction eagerly loads the weights; ``run_inference`` then delegates
    to the underlying inferencer's input-processing loop.
    """

    def __init__(self, model_path: str, input_dir: str = 'input_data', output_dir: str = 'output_data'):
        """Create the wrapped ``Inferencer`` and load weights from *model_path*.

        NOTE(review): this reaches into the private ``Inferencer._load_model``
        hook; the second element of its return tuple is discarded here —
        confirm that is intentional against ``inference_util``.
        """
        self.model_path = model_path
        self.inferencer = Inferencer(input_dir, output_dir)

        loaded_model, _unused = self.inferencer._load_model(model_path)
        self.inferencer.model = loaded_model

    def run_inference(self):
        """Run inference using the specified model"""
        return self.inferencer.process_input()
|
|
|
def main():
    """CLI entry point: parse arguments, run inference, print a summary line."""
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--model-path', required=True, help='Path to the model weights')
    cli.add_argument('--input-dir', default='input_data')
    cli.add_argument('--output-dir', default='output_data')
    opts = cli.parse_args()

    # Build the wrapper (loads the weights) and process everything in input-dir.
    runner = InferenceWrapper(opts.model_path, opts.input_dir, opts.output_dir)
    outputs = runner.run_inference()

    print(f"Processed {len(outputs)} inputs using model: {opts.model_path}")
|
|
|
# Run the CLI only when executed directly, not when imported as a module.
if __name__ == "__main__":

    main()
|
|