|
|
|
import torch |
|
from torchvision import transforms, datasets |
|
from PIL import Image |
|
import json |
|
from pathlib import Path |
|
from model import MNISTModel |
|
import os |
|
import sys |
|
|
|
class Inferencer:
    """Run a trained ``MNISTModel`` over serialized tensors in a directory.

    Every ``*.pt`` file in ``input_dir`` is loaded, classified, and the
    per-file prediction and confidence are collected and written to
    ``output_dir / 'results.json'``.
    """

    def __init__(self, input_dir: str = 'input_data', output_dir: str = 'output_data',
                 model_path: str = 'best_model.pth'):
        """Pick a device, load the model weights, and record the I/O directories.

        Args:
            input_dir: Directory scanned for ``*.pt`` input tensors.
            output_dir: Directory that receives ``results.json`` (created on demand).
            model_path: Path of the state dict loaded into ``MNISTModel``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model, _ = self._load_model(model_path)
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        # NOTE(review): this transform is never applied inside this class -- .pt
        # inputs are fed to the model exactly as loaded. The constants are presumably
        # the MNIST mean/std; confirm whether inputs are already normalized upstream.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def _load_model(self, model_path='best_model.pth'):
        """Build ``MNISTModel`` on ``self.device`` and load its trained weights.

        ``weights_only=True`` restricts unpickling to tensors and plain containers,
        and ``map_location`` lets a checkpoint saved on GPU load on any device.

        Returns:
            ``(model, device)``: the model in eval mode and the device it lives on.
        """
        model = MNISTModel().to(self.device)
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()  # disable training-only behaviour (dropout, batch-norm updates)
        return model, self.device

    def predict(self, input_tensor: torch.Tensor):
        """Classify a single sample.

        Args:
            input_tensor: A 3-D tensor (a leading batch dimension is added) or a
                4-D tensor holding a batch of one. NOTE(review): the unsqueeze below
                implies a (C, H, W) layout -- confirm against ``MNISTModel``.

        Returns:
            ``(prediction, confidence)``: the argmax class index (int) and its
            softmax probability (float).

        Raises:
            RuntimeError: if the batch holds more than one sample, because
                ``.item()`` cannot convert a multi-element tensor.
        """
        with torch.no_grad():
            if input_tensor.dim() == 3:
                input_tensor = input_tensor.unsqueeze(0)

            input_tensor = input_tensor.to(self.device)
            output = self.model(input_tensor)
            probs = torch.softmax(output, dim=1)
            prediction = output.argmax(1).item()
            confidence = probs[0][prediction].item()
            return prediction, confidence

    def process_input(self):
        """Classify every ``*.pt`` file in ``input_dir`` and write ``results.json``.

        Files that fail to load or predict are reported on stderr and skipped, so
        one bad input does not abort the whole run.

        Returns:
            A list of ``{"filename", "prediction", "confidence"}`` dicts in sorted
            filename order. The same list is written to ``output_dir/results.json``.
        """
        self.output_dir.mkdir(parents=True, exist_ok=True)

        results = []
        for file_path in sorted(self.input_dir.glob('*.pt')):
            try:
                # Input files are external data. weights_only=True refuses arbitrary
                # pickled objects, which plain torch.load would unpickle and could
                # execute. map_location lets CUDA-saved tensors load on a CPU-only
                # host.
                input_tensor = torch.load(file_path, map_location=self.device,
                                          weights_only=True)
                prediction, confidence = self.predict(input_tensor)
            except Exception as e:
                # Deliberate best-effort batch: report and move on to the next file.
                print(f"Error processing {file_path}: {str(e)}", file=sys.stderr)
                continue

            results.append({
                "filename": file_path.name,
                "prediction": prediction,
                "confidence": confidence,
            })

        with open(self.output_dir / 'results.json', 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        return results
|
|
|
def main():
    """Parse command-line options, run inference on the input directory, and print a summary.

    Options:
        --input-dir   directory scanned for ``*.pt`` tensors (default: ``input_data``)
        --output-dir  directory that receives ``results.json`` (default: ``output_data``)
    """
    import argparse

    cli = argparse.ArgumentParser()
    for flag, fallback in (('--input-dir', 'input_data'), ('--output-dir', 'output_data')):
        cli.add_argument(flag, default=fallback)
    options = cli.parse_args()

    # The count covers only inputs that produced a result; files that failed
    # were skipped by process_input.
    processed = Inferencer(options.input_dir, options.output_dir).process_input()
    print(f"Processed {len(processed)} inputs")
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when it is imported.
    main()