# inference.py
import json
import os
import sys
from pathlib import Path

import torch
from torchvision import transforms

from model import MNISTModel

class Inferencer:
    """Run a trained MNISTModel over every tensor file in an input directory."""

    def __init__(self, input_dir: str = 'input_data', output_dir: str = 'output_data'):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_model()
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        # Standard MNIST normalization (mean 0.1307, std 0.3081) for raw PIL images.
        # Note: process_input() loads ready-made .pt tensors and does not apply it.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def _load_model(self, model_path: str = 'best_model.pth'):
        """Load the trained weights onto self.device and switch to eval mode."""
        model = MNISTModel().to(self.device)
        model.load_state_dict(
            torch.load(model_path, map_location=self.device, weights_only=True)
        )
        model.eval()
        return model
    
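    def predict(self, input_tensor: torch.Tensor):
        """Return (predicted digit, softmax confidence) for a single image tensor.

        Accepts a tensor of shape (H, W), (C, H, W) or (1, C, H, W).
        """
        # Add missing dimensions: (H, W) -> (1, H, W) -> (1, 1, H, W).
        if input_tensor.dim() == 2:
            input_tensor = input_tensor.unsqueeze(0)
        if input_tensor.dim() == 3:
            input_tensor = input_tensor.unsqueeze(0)
        if input_tensor.size(0) != 1:
            raise ValueError(
                f"predict() expects a single image, got a batch of {input_tensor.size(0)}"
            )

        with torch.no_grad():
            output = self.model(input_tensor.to(self.device))
            probs = torch.softmax(output, dim=1)
            prediction = probs[0].argmax().item()
            confidence = probs[0, prediction].item()
        return prediction, confidence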

    def process_input(self):
        """Process all images in input directory."""
        # Create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)
        
        results = []
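        # Process each *.pt file (one saved image tensor per file) in sorted order.
        for file_path in sorted(self.input_dir.glob('*.pt')):
            try:
                # Load on CPU so GPU-saved tensors also open on CPU-only hosts;
                # predict() moves the tensor to self.device.
                input_tensor = torch.load(file_path, map_location='cpu', weights_only=True)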
                
                # Get prediction
                prediction, confidence = self.predict(input_tensor)
                
                results.append({
                    "filename": file_path.name,
                    "prediction": prediction,
                    "confidence": confidence
                })
                
            except Exception as e:
                # Log and skip unreadable or malformed files instead of aborting the run.
                print(f"Error processing {file_path}: {e}", file=sys.stderr)
        
        # Save results
        with open(self.output_dir / 'results.json', 'w') as f:
            json.dump(results, f, indent=2)
        
        return results
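

# Illustrative sketch (hypothetical helper, not called by Inferencer): the
# softmax -> argmax -> confidence steps that predict() performs with torch,
# written in plain Python for one list of logits. The reported "confidence"
# is simply the softmax probability of the winning class.
def softmax_confidence(logits):
    """Return (index of the largest logit, its softmax probability)."""
    import math

    shift = max(logits)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # First index with the highest probability, like torch's argmax on ties.
    prediction = max(range(len(probs)), key=probs.__getitem__)
    return prediction, probs[prediction]


# Example: softmax_confidence([0.0, 2.0, 0.0]) returns (1, ~0.787).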

def main():
    """Command-line entry point: classify every .pt tensor in --input-dir."""
    import argparse
    parser = argparse.ArgumentParser(description="Run MNIST inference on .pt tensor files.")
    parser.add_argument('--input-dir', default='input_data',
                        help="directory containing the input .pt tensor files")
    parser.add_argument('--output-dir', default='output_data',
                        help="directory where results.json is written")
    args = parser.parse_args()

    inferencer = Inferencer(args.input_dir, args.output_dir)
    results = inferencer.process_input()
    print(f"Processed {len(results)} inputs")
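

# Illustrative sketch (hypothetical helper, not part of the original script):
# summarize the list that process_input() returns and writes to results.json,
# where each entry has "filename", "prediction" and "confidence" keys.
# The default 0.9 threshold is an arbitrary assumption; tune it per use case.
def summarize_results(results, threshold=0.9):
    """Count predictions per digit and list files below the confidence threshold."""
    from collections import Counter

    counts = Counter(r["prediction"] for r in results)
    low_confidence = sorted(
        r["filename"] for r in results if r["confidence"] < threshold
    )
    return dict(counts), low_confidence


# Example: summarize_results(Inferencer().process_input()) after a run.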

if __name__ == "__main__":
    main()