tezuesh commited on
Commit
5cdf407
·
verified ·
1 Parent(s): 8f01982

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +87 -0
inference.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference.py
2
+ import torch
3
+ from torchvision import transforms, datasets
4
+ from PIL import Image
5
+ import json
6
+ from pathlib import Path
7
+ from model import MNISTModel
8
+ import os
9
+ import sys
10
+
11
+ class Inferencer:
12
+ def __init__(self, input_dir: str = 'input_data', output_dir: str = 'output_data'):
13
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+ self.model, _ = self._load_model()
15
+ self.input_dir = Path(input_dir)
16
+ self.output_dir = Path(output_dir)
17
+ self.transform = transforms.Compose([
18
+ transforms.ToTensor(),
19
+ transforms.Normalize((0.1307,), (0.3081,))
20
+ ])
21
+
22
+ def _load_model(self, model_path='saved_models/best_model.pth'):
23
+ """Load the trained model."""
24
+ model = MNISTModel().to(self.device)
25
+ model.load_state_dict(
26
+ torch.load(model_path, map_location=self.device, weights_only=True)
27
+ )
28
+ model.eval()
29
+ return model, self.device
30
+
31
+ def predict(self, input_tensor: torch.Tensor):
32
+ """Make prediction on the input tensor."""
33
+ with torch.no_grad():
34
+ if input_tensor.dim() == 3:
35
+ input_tensor = input_tensor.unsqueeze(0)
36
+
37
+ input_tensor = input_tensor.to(self.device)
38
+ output = self.model(input_tensor)
39
+ probs = torch.softmax(output, dim=1)
40
+ prediction = output.argmax(1).item()
41
+ confidence = probs[0][prediction].item()
42
+ return prediction, confidence
43
+
44
+ def process_input(self):
45
+ """Process all images in input directory."""
46
+ # Create output directory if it doesn't exist
47
+ os.makedirs(self.output_dir, exist_ok=True)
48
+
49
+ results = []
50
+ # Process each file in input directory
51
+ for file_path in sorted(self.input_dir.glob('*.pt')): # For tensor files
52
+ try:
53
+ # Load tensor
54
+ input_tensor = torch.load(file_path)
55
+
56
+ # Get prediction
57
+ prediction, confidence = self.predict(input_tensor)
58
+
59
+ results.append({
60
+ "filename": file_path.name,
61
+ "prediction": prediction,
62
+ "confidence": confidence
63
+ })
64
+
65
+ except Exception as e:
66
+ print(f"Error processing {file_path}: {str(e)}", file=sys.stderr)
67
+
68
+ # Save results
69
+ with open(self.output_dir / 'results.json', 'w') as f:
70
+ json.dump(results, f, indent=2)
71
+
72
+ return results
73
+
74
+ def main():
75
+ # Accept input/output directories as arguments
76
+ import argparse
77
+ parser = argparse.ArgumentParser()
78
+ parser.add_argument('--input-dir', default='input_data')
79
+ parser.add_argument('--output-dir', default='output_data')
80
+ args = parser.parse_args()
81
+
82
+ inferencer = Inferencer(args.input_dir, args.output_dir)
83
+ results = inferencer.process_input()
84
+ print(f"Processed {len(results)} inputs")
85
+
86
+ if __name__ == "__main__":
87
+ main()