tezuesh commited on
Commit
2dea402
·
verified ·
1 Parent(s): 28bc2d0

Upload folder using huggingface_hub

Browse files
Files changed (45) hide show
  1. __pycache__/model.cpython-310.pyc +0 -0
  2. inference.py +75 -18
  3. mnist/.gitattributes +35 -0
  4. mnist/Python-3.10.0/Lib/ctypes/__pycache__/__init__.cpython-310.pyc +0 -0
  5. mnist/Python-3.10.0/Lib/ctypes/__pycache__/_endian.cpython-310.pyc +0 -0
  6. mnist/Python-3.10.0/Lib/ensurepip/__pycache__/__init__.cpython-310.pyc +0 -0
  7. mnist/Python-3.10.0/Lib/ensurepip/__pycache__/__main__.cpython-310.pyc +0 -0
  8. mnist/Python-3.10.0/Lib/ensurepip/_bundled/__pycache__/__init__.cpython-310.pyc +0 -0
  9. mnist/Python-3.10.0/Lib/html/__pycache__/__init__.cpython-310.pyc +0 -0
  10. mnist/Python-3.10.0/Lib/html/__pycache__/entities.cpython-310.pyc +0 -0
  11. mnist/Python-3.10.0/Lib/html/__pycache__/parser.cpython-310.pyc +0 -0
  12. mnist/Python-3.10.0/Lib/http/__pycache__/__init__.cpython-310.pyc +0 -0
  13. mnist/Python-3.10.0/Lib/http/__pycache__/client.cpython-310.pyc +0 -0
  14. mnist/Python-3.10.0/Lib/http/__pycache__/cookiejar.cpython-310.pyc +0 -0
  15. mnist/Python-3.10.0/Lib/http/__pycache__/cookies.cpython-310.pyc +0 -0
  16. mnist/Python-3.10.0/Lib/lib2to3/__pycache__/__init__.cpython-310.pyc +0 -0
  17. mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/__init__.cpython-310.pyc +0 -0
  18. mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/driver.cpython-310.pyc +0 -0
  19. mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/grammar.cpython-310.pyc +0 -0
  20. mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/parse.cpython-310.pyc +0 -0
  21. mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/pgen.cpython-310.pyc +0 -0
  22. mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/token.cpython-310.pyc +0 -0
  23. mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/tokenize.cpython-310.pyc +0 -0
  24. mnist/Python-3.10.0/Lib/urllib/__pycache__/__init__.cpython-310.pyc +0 -0
  25. mnist/Python-3.10.0/Lib/urllib/__pycache__/error.cpython-310.pyc +0 -0
  26. mnist/Python-3.10.0/Lib/urllib/__pycache__/parse.cpython-310.pyc +0 -0
  27. mnist/Python-3.10.0/Lib/urllib/__pycache__/request.cpython-310.pyc +0 -0
  28. mnist/Python-3.10.0/Lib/urllib/__pycache__/response.cpython-310.pyc +0 -0
  29. mnist/Python-3.10.0/Lib/xml/parsers/__pycache__/__init__.cpython-310.pyc +0 -0
  30. mnist/Python-3.10.0/Lib/xml/parsers/__pycache__/expat.cpython-310.pyc +0 -0
  31. mnist/Python-3.10.0/Lib/xml/sax/__pycache__/__init__.cpython-310.pyc +0 -0
  32. mnist/Python-3.10.0/Lib/xml/sax/__pycache__/_exceptions.cpython-310.pyc +0 -0
  33. mnist/Python-3.10.0/Lib/xml/sax/__pycache__/handler.cpython-310.pyc +0 -0
  34. mnist/Python-3.10.0/Lib/xml/sax/__pycache__/saxutils.cpython-310.pyc +0 -0
  35. mnist/Python-3.10.0/Lib/xml/sax/__pycache__/xmlreader.cpython-310.pyc +0 -0
  36. mnist/Python-3.10.0/Lib/xmlrpc/__pycache__/__init__.cpython-310.pyc +0 -0
  37. mnist/Python-3.10.0/Lib/xmlrpc/__pycache__/client.cpython-310.pyc +0 -0
  38. mnist/__init__.py +0 -0
  39. mnist/__pycache__/model.cpython-310.pyc +0 -0
  40. mnist/best_model.pth +3 -0
  41. mnist/config.json +9 -0
  42. mnist/inference.py +92 -0
  43. mnist/inference_util.py +87 -0
  44. mnist/model.py +16 -0
  45. mnist/random2.txt +1 -0
__pycache__/model.cpython-310.pyc ADDED
Binary file (877 Bytes). View file
 
inference.py CHANGED
@@ -1,35 +1,92 @@
1
  import torch
2
  from torchvision import transforms
3
- from pathlib import Path
4
- import json
5
- import os
6
- import sys
7
  from model import MNISTModel
8
- from inference_util import Inferencer
9
 
10
 
11
  class InferenceWrapper:
12
- def __init__(self, model_path: str, input_dir: str = 'input_data', output_dir: str = 'output_data'):
 
 
 
 
 
 
 
13
  self.model_path = model_path
14
- self.inferencer = Inferencer(input_dir, output_dir)
15
- # Override the model with our specified model path
16
- self.inferencer.model, _ = self.inferencer._load_model(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- def run_inference(self):
19
- """Run inference using the specified model"""
20
- return self.inferencer.process_input()
21
 
22
  def main():
23
  import argparse
24
  parser = argparse.ArgumentParser()
25
  parser.add_argument('--model-path', required=True, help='Path to the model weights')
26
- parser.add_argument('--input-dir', default='input_data')
27
- parser.add_argument('--output-dir', default='output_data')
28
  args = parser.parse_args()
29
 
30
- wrapper = InferenceWrapper(args.model_path, args.input_dir, args.output_dir)
31
- results = wrapper.run_inference()
32
- print(f"Processed {len(results)} inputs using model: {args.model_path}")
 
 
 
 
 
 
 
 
 
 
33
 
34
  if __name__ == "__main__":
35
- main()
 
1
  import torch
2
  from torchvision import transforms
 
 
 
 
3
  from model import MNISTModel
 
4
 
5
 
6
  class InferenceWrapper:
7
+ def __init__(self, model_path: str):
8
+ """
9
+ Initialize the inference wrapper with a model path.
10
+
11
+ Args:
12
+ model_path (str): Path to the model weights file
13
+ """
14
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  self.model_path = model_path
16
+ self.model = self._load_model()
17
+ self.transform = transforms.Compose([
18
+ transforms.ToTensor(),
19
+ transforms.Normalize((0.1307,), (0.3081,))
20
+ ])
21
+
22
+ def _load_model(self):
23
+ """Load and return the model."""
24
+ model = MNISTModel().to(self.device)
25
+ model.load_state_dict(
26
+ torch.load(self.model_path, map_location=self.device, weights_only=True)
27
+ )
28
+ model.eval()
29
+ return model
30
+
31
+ def predict_tensor(self, input_tensor: torch.Tensor):
32
+ """
33
+ Run inference on a single input tensor.
34
+
35
+ Args:
36
+ input_tensor (torch.Tensor): Input tensor of shape [1, 28, 28] or [N, 1, 28, 28]
37
+
38
+ Returns:
39
+ tuple: (prediction, confidence)
40
+ """
41
+ with torch.no_grad():
42
+ if input_tensor.dim() == 3:
43
+ input_tensor = input_tensor.unsqueeze(0)
44
+
45
+ input_tensor = input_tensor.to(self.device)
46
+ output = self.model(input_tensor)
47
+ probs = torch.softmax(output, dim=1)
48
+ prediction = output.argmax(1).item()
49
+ confidence = probs[0][prediction].item()
50
+ return prediction, confidence
51
+
52
+ def predict_batch(self, input_tensors: torch.Tensor):
53
+ """
54
+ Run inference on a batch of input tensors.
55
+
56
+ Args:
57
+ input_tensors (torch.Tensor): Batch of input tensors of shape [N, 1, 28, 28]
58
+
59
+ Returns:
60
+ tuple: (predictions, confidences)
61
+ """
62
+ with torch.no_grad():
63
+ input_tensors = input_tensors.to(self.device)
64
+ output = self.model(input_tensors)
65
+ probs = torch.softmax(output, dim=1)
66
+ predictions = output.argmax(1)
67
+ confidences = torch.gather(probs, 1, predictions.unsqueeze(1)).squeeze(1)
68
+ return predictions.cpu().numpy(), confidences.cpu().numpy()
69
 
 
 
 
70
 
71
  def main():
72
  import argparse
73
  parser = argparse.ArgumentParser()
74
  parser.add_argument('--model-path', required=True, help='Path to the model weights')
 
 
75
  args = parser.parse_args()
76
 
77
+ # Example usage
78
+ wrapper = InferenceWrapper(args.model_path)
79
+
80
+ # Example single inference
81
+ test_input = torch.randn(1, 28, 28)
82
+ prediction, confidence = wrapper.predict_tensor(test_input)
83
+ print(f"Single prediction: {prediction}, confidence: {confidence:.4f}")
84
+
85
+ # Example batch inference
86
+ batch_input = torch.randn(4, 1, 28, 28)
87
+ predictions, confidences = wrapper.predict_batch(batch_input)
88
+ print(f"Batch predictions: {predictions}")
89
+ print(f"Batch confidences: {confidences}")
90
 
91
  if __name__ == "__main__":
92
+ main()
mnist/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
mnist/Python-3.10.0/Lib/ctypes/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (15.9 kB). View file
 
mnist/Python-3.10.0/Lib/ctypes/__pycache__/_endian.cpython-310.pyc ADDED
Binary file (1.95 kB). View file
 
mnist/Python-3.10.0/Lib/ensurepip/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (6.3 kB). View file
 
mnist/Python-3.10.0/Lib/ensurepip/__pycache__/__main__.cpython-310.pyc ADDED
Binary file (287 Bytes). View file
 
mnist/Python-3.10.0/Lib/ensurepip/_bundled/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (195 Bytes). View file
 
mnist/Python-3.10.0/Lib/html/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.38 kB). View file
 
mnist/Python-3.10.0/Lib/html/__pycache__/entities.cpython-310.pyc ADDED
Binary file (144 kB). View file
 
mnist/Python-3.10.0/Lib/html/__pycache__/parser.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
mnist/Python-3.10.0/Lib/http/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (6.47 kB). View file
 
mnist/Python-3.10.0/Lib/http/__pycache__/client.cpython-310.pyc ADDED
Binary file (35.1 kB). View file
 
mnist/Python-3.10.0/Lib/http/__pycache__/cookiejar.cpython-310.pyc ADDED
Binary file (53.5 kB). View file
 
mnist/Python-3.10.0/Lib/http/__pycache__/cookies.cpython-310.pyc ADDED
Binary file (15.4 kB). View file
 
mnist/Python-3.10.0/Lib/lib2to3/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (355 Bytes). View file
 
mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (220 Bytes). View file
 
mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/driver.cpython-310.pyc ADDED
Binary file (5.2 kB). View file
 
mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/grammar.cpython-310.pyc ADDED
Binary file (5.77 kB). View file
 
mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/parse.cpython-310.pyc ADDED
Binary file (6.57 kB). View file
 
mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/pgen.cpython-310.pyc ADDED
Binary file (9.91 kB). View file
 
mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/token.cpython-310.pyc ADDED
Binary file (1.94 kB). View file
 
mnist/Python-3.10.0/Lib/lib2to3/pgen2/__pycache__/tokenize.cpython-310.pyc ADDED
Binary file (15.2 kB). View file
 
mnist/Python-3.10.0/Lib/urllib/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (183 Bytes). View file
 
mnist/Python-3.10.0/Lib/urllib/__pycache__/error.cpython-310.pyc ADDED
Binary file (2.87 kB). View file
 
mnist/Python-3.10.0/Lib/urllib/__pycache__/parse.cpython-310.pyc ADDED
Binary file (33.9 kB). View file
 
mnist/Python-3.10.0/Lib/urllib/__pycache__/request.cpython-310.pyc ADDED
Binary file (71.3 kB). View file
 
mnist/Python-3.10.0/Lib/urllib/__pycache__/response.cpython-310.pyc ADDED
Binary file (3.52 kB). View file
 
mnist/Python-3.10.0/Lib/xml/parsers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (360 Bytes). View file
 
mnist/Python-3.10.0/Lib/xml/parsers/__pycache__/expat.cpython-310.pyc ADDED
Binary file (389 Bytes). View file
 
mnist/Python-3.10.0/Lib/xml/sax/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.26 kB). View file
 
mnist/Python-3.10.0/Lib/xml/sax/__pycache__/_exceptions.cpython-310.pyc ADDED
Binary file (5.43 kB). View file
 
mnist/Python-3.10.0/Lib/xml/sax/__pycache__/handler.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
mnist/Python-3.10.0/Lib/xml/sax/__pycache__/saxutils.cpython-310.pyc ADDED
Binary file (12.7 kB). View file
 
mnist/Python-3.10.0/Lib/xml/sax/__pycache__/xmlreader.cpython-310.pyc ADDED
Binary file (16.5 kB). View file
 
mnist/Python-3.10.0/Lib/xmlrpc/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (183 Bytes). View file
 
mnist/Python-3.10.0/Lib/xmlrpc/__pycache__/client.cpython-310.pyc ADDED
Binary file (34.3 kB). View file
 
mnist/__init__.py ADDED
File without changes
mnist/__pycache__/model.cpython-310.pyc ADDED
Binary file (877 Bytes). View file
 
mnist/best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c72423094523210c88faa0306abaf81f4352b99b3865d2a80671a361eae0836
3
+ size 131
mnist/config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "mnist_classifier",
3
+ "input_size": [1, 28, 28],
4
+ "hidden_size": 128,
5
+ "num_classes": 10,
6
+ "dropout": 0.5,
7
+ "mean": 0.1307,
8
+ "std": 0.3081
9
+ }
mnist/inference.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from model import MNISTModel
4
+
5
+
6
+ class InferenceWrapper:
7
+ def __init__(self, model_path: str):
8
+ """
9
+ Initialize the inference wrapper with a model path.
10
+
11
+ Args:
12
+ model_path (str): Path to the model weights file
13
+ """
14
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ self.model_path = model_path
16
+ self.model = self._load_model()
17
+ self.transform = transforms.Compose([
18
+ transforms.ToTensor(),
19
+ transforms.Normalize((0.1307,), (0.3081,))
20
+ ])
21
+
22
+ def _load_model(self):
23
+ """Load and return the model."""
24
+ model = MNISTModel().to(self.device)
25
+ model.load_state_dict(
26
+ torch.load(self.model_path, map_location=self.device, weights_only=True)
27
+ )
28
+ model.eval()
29
+ return model
30
+
31
+ def predict_tensor(self, input_tensor: torch.Tensor):
32
+ """
33
+ Run inference on a single input tensor.
34
+
35
+ Args:
36
+ input_tensor (torch.Tensor): Input tensor of shape [1, 28, 28] or [N, 1, 28, 28]
37
+
38
+ Returns:
39
+ tuple: (prediction, confidence)
40
+ """
41
+ with torch.no_grad():
42
+ if input_tensor.dim() == 3:
43
+ input_tensor = input_tensor.unsqueeze(0)
44
+
45
+ input_tensor = input_tensor.to(self.device)
46
+ output = self.model(input_tensor)
47
+ probs = torch.softmax(output, dim=1)
48
+ prediction = output.argmax(1).item()
49
+ confidence = probs[0][prediction].item()
50
+ return prediction, confidence
51
+
52
+ def predict_batch(self, input_tensors: torch.Tensor):
53
+ """
54
+ Run inference on a batch of input tensors.
55
+
56
+ Args:
57
+ input_tensors (torch.Tensor): Batch of input tensors of shape [N, 1, 28, 28]
58
+
59
+ Returns:
60
+ tuple: (predictions, confidences)
61
+ """
62
+ with torch.no_grad():
63
+ input_tensors = input_tensors.to(self.device)
64
+ output = self.model(input_tensors)
65
+ probs = torch.softmax(output, dim=1)
66
+ predictions = output.argmax(1)
67
+ confidences = torch.gather(probs, 1, predictions.unsqueeze(1)).squeeze(1)
68
+ return predictions.cpu().numpy(), confidences.cpu().numpy()
69
+
70
+
71
+ def main():
72
+ import argparse
73
+ parser = argparse.ArgumentParser()
74
+ parser.add_argument('--model-path', required=True, help='Path to the model weights')
75
+ args = parser.parse_args()
76
+
77
+ # Example usage
78
+ wrapper = InferenceWrapper(args.model_path)
79
+
80
+ # Example single inference
81
+ test_input = torch.randn(1, 28, 28)
82
+ prediction, confidence = wrapper.predict_tensor(test_input)
83
+ print(f"Single prediction: {prediction}, confidence: {confidence:.4f}")
84
+
85
+ # Example batch inference
86
+ batch_input = torch.randn(4, 1, 28, 28)
87
+ predictions, confidences = wrapper.predict_batch(batch_input)
88
+ print(f"Batch predictions: {predictions}")
89
+ print(f"Batch confidences: {confidences}")
90
+
91
+ if __name__ == "__main__":
92
+ main()
mnist/inference_util.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference.py
2
+ import torch
3
+ from torchvision import transforms, datasets
4
+ from PIL import Image
5
+ import json
6
+ from pathlib import Path
7
+ from model import MNISTModel
8
+ import os
9
+ import sys
10
+
11
+ class Inferencer:
12
+ def __init__(self, input_dir: str = 'input_data', output_dir: str = 'output_data'):
13
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+ self.model, _ = self._load_model()
15
+ self.input_dir = Path(input_dir)
16
+ self.output_dir = Path(output_dir)
17
+ self.transform = transforms.Compose([
18
+ transforms.ToTensor(),
19
+ transforms.Normalize((0.1307,), (0.3081,))
20
+ ])
21
+
22
+ def _load_model(self, model_path='best_model.pth'):
23
+ """Load the trained model."""
24
+ model = MNISTModel().to(self.device)
25
+ model.load_state_dict(
26
+ torch.load(model_path, map_location=self.device, weights_only=True)
27
+ )
28
+ model.eval()
29
+ return model, self.device
30
+
31
+ def predict(self, input_tensor: torch.Tensor):
32
+ """Make prediction on the input tensor."""
33
+ with torch.no_grad():
34
+ if input_tensor.dim() == 3:
35
+ input_tensor = input_tensor.unsqueeze(0)
36
+
37
+ input_tensor = input_tensor.to(self.device)
38
+ output = self.model(input_tensor)
39
+ probs = torch.softmax(output, dim=1)
40
+ prediction = output.argmax(1).item()
41
+ confidence = probs[0][prediction].item()
42
+ return prediction, confidence
43
+
44
+ def process_input(self):
45
+ """Process all images in input directory."""
46
+ # Create output directory if it doesn't exist
47
+ os.makedirs(self.output_dir, exist_ok=True)
48
+
49
+ results = []
50
+ # Process each file in input directory
51
+ for file_path in sorted(self.input_dir.glob('*.pt')): # For tensor files
52
+ try:
53
+ # Load tensor
54
+ input_tensor = torch.load(file_path)
55
+
56
+ # Get prediction
57
+ prediction, confidence = self.predict(input_tensor)
58
+
59
+ results.append({
60
+ "filename": file_path.name,
61
+ "prediction": prediction,
62
+ "confidence": confidence
63
+ })
64
+
65
+ except Exception as e:
66
+ print(f"Error processing {file_path}: {str(e)}", file=sys.stderr)
67
+
68
+ # Save results
69
+ with open(self.output_dir / 'results.json', 'w') as f:
70
+ json.dump(results, f, indent=2)
71
+
72
+ return results
73
+
74
+ def main():
75
+ # Accept input/output directories as arguments
76
+ import argparse
77
+ parser = argparse.ArgumentParser()
78
+ parser.add_argument('--input-dir', default='input_data')
79
+ parser.add_argument('--output-dir', default='output_data')
80
+ args = parser.parse_args()
81
+
82
+ inferencer = Inferencer(args.input_dir, args.output_dir)
83
+ results = inferencer.process_input()
84
+ print(f"Processed {len(results)} inputs")
85
+
86
+ if __name__ == "__main__":
87
+ main()
mnist/model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class MNISTModel(nn.Module):
5
+ def __init__(self):
6
+ super(MNISTModel, self).__init__()
7
+ self.fc1 = nn.Linear(28 * 28, 128) # MNIST images are 28x28
8
+ self.fc2 = nn.Linear(128, 10)
9
+ self.dropout = nn.Dropout(0.5)
10
+
11
+ def forward(self, x):
12
+ x = x.view(-1, 28 * 28) # Flatten the input
13
+ x = torch.relu(self.fc1(x))
14
+ x = self.dropout(x)
15
+ x = self.fc2(x)
16
+ return x
mnist/random2.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ [!] Failed to build directory tree: No such file or directory (os error 2)