File size: 1,222 Bytes
5cdf407
85dd38b
5cdf407
85dd38b
5cdf407
 
85dd38b
 
5cdf407
 
85dd38b
 
 
 
 
 
 
 
 
 
5cdf407
 
 
 
85dd38b
5cdf407
 
 
 
85dd38b
 
 
5cdf407
 
85dd38b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torchvision import transforms
from pathlib import Path
import json
import os
import sys
from model import MNISTModel
from inference_util import Inferencer


class InferenceWrapper:
    """Drive an ``Inferencer`` whose model is loaded from an explicit weights path.

    Args:
        model_path: Path to the model weights; handed to
            ``Inferencer._load_model``.
        input_dir: Input directory forwarded to the ``Inferencer`` constructor.
        output_dir: Output directory forwarded to the ``Inferencer`` constructor.
    """

    def __init__(self, model_path: str, input_dir: str = 'input_data', output_dir: str = 'output_data'):
        self.model_path = model_path
        runner = Inferencer(input_dir, output_dir)
        self.inferencer = runner
        # Replace the inferencer's model with the one loaded from model_path.
        # ``_load_model`` must return exactly two values; only the first (the
        # model) is kept. NOTE(review): this relies on a private Inferencer
        # method — confirm its contract if Inferencer changes.
        loaded_model, _discarded = runner._load_model(model_path)
        runner.model = loaded_model

    def run_inference(self):
        """Run inference over the inputs using the wrapped inferencer.

        Returns:
            The value of ``Inferencer.process_input()``. ``main`` calls
            ``len()`` on it, so presumably a sized collection — confirm.
        """
        inferencer = self.inferencer
        return inferencer.process_input()

def main(argv=None):
    """Command-line entry point: run inference with a user-specified model.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which matches
            the previous behaviour when run as a script; passing a list lets
            callers and tests drive the CLI programmatically.

    Raises:
        SystemExit: Raised by argparse for invalid arguments (status 2) or
            ``--help`` (status 0).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Run inference with the model weights given by --model-path.'
    )
    parser.add_argument('--model-path', required=True, help='Path to the model weights')
    parser.add_argument(
        '--input-dir',
        default='input_data',
        help='Input directory passed to the inferencer (default: %(default)s)',
    )
    parser.add_argument(
        '--output-dir',
        default='output_data',
        help='Output directory passed to the inferencer (default: %(default)s)',
    )
    args = parser.parse_args(argv)

    wrapper = InferenceWrapper(args.model_path, args.input_dir, args.output_dir)
    results = wrapper.run_inference()
    print(f"Processed {len(results)} inputs using model: {args.model_path}")


# Run the CLI only when executed directly, not when imported.
if __name__ == "__main__":
    main()