0xnu
/

Image Classification
Keras
vision
File size: 5,630 Bytes
9e430b7
 
 
 
980f6f8
9e430b7
 
058f4c8
980f6f8
530b58b
e99dff6
 
f5b9188
 
9facfab
 
bcfe26b
 
730ff55
 
 
 
 
 
 
 
 
a45bda7
 
 
 
730ff55
 
 
 
 
a45bda7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730ff55
 
 
a45bda7
 
 
 
 
 
 
 
730ff55
a45bda7
730ff55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a45bda7
730ff55
 
 
a45bda7
730ff55
 
 
a45bda7
 
 
730ff55
 
 
 
a45bda7
730ff55
 
 
 
 
 
 
 
 
a45bda7
730ff55
 
 
 
 
 
 
a45bda7
730ff55
 
 
 
 
 
 
 
 
 
 
 
a45bda7
 
 
 
 
 
 
 
730ff55
 
 
 
 
 
 
 
 
 
f5b9188
7d5ba4a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
---
license: apache-2.0
tags:
- vision
- image-classification
datasets:
- dmitva/the-mnist-database
inference: true
pipeline_tag: image-classification
widget:
- text: "Enter image URL"
  example: https://miro.medium.com/v2/resize:fit:720/format:webp/1*w7pBsjI3t3ZP-4Gdog-JdQ.png
---

The MNIST OCR (Optical Character Recognition) model is a deep learning model trained to recognise and classify handwritten digits from 0 to 9. It is trained on the MNIST dataset, whose training set contains 60,000 small, square, 28×28-pixel grayscale images of single handwritten digits. This makes it highly accurate at recognising isolated handwritten digits written in a style similar to those in the training set.

![Training History](training_history.png "Training History")

### Install Packages

```sh
pip install numpy opencv-python requests pillow transformers tensorflow huggingface_hub
```

### Usage

```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np
import cv2
import requests
from PIL import Image
from io import BytesIO
from typing import List, Optional
from huggingface_hub import hf_hub_download
import tensorflow as tf
import pickle

class ImageTokenizer:
    """Translate pixel values to dense integer tokens and back.

    ``fit`` collects the distinct pixel values seen in the training images;
    tokens are then assigned in ascending pixel order (smallest value -> 0).
    """

    def __init__(self):
        # Keep these attribute names stable: a pickled tokenizer restores its
        # state into attributes with exactly these names.
        self.unique_pixels = set()
        self.pixel_to_token = {}
        self.token_to_pixel = {}

    def fit(self, images):
        """Add every pixel value in *images* to the vocabulary and rebuild both tables.

        Values from earlier ``fit`` calls are retained, so repeated calls accumulate.
        """
        for img in images:
            self.unique_pixels |= set(np.unique(img))
        vocabulary = sorted(self.unique_pixels)
        self.token_to_pixel = dict(enumerate(vocabulary))
        self.pixel_to_token = {value: token for token, value in self.token_to_pixel.items()}

    def tokenize(self, images):
        """Map each pixel in *images* to its token (``None`` for values never fitted)."""
        return self._lookup(self.pixel_to_token, images)

    def detokenize(self, tokens):
        """Map each token in *tokens* back to its pixel value (``None`` if unknown)."""
        return self._lookup(self.token_to_pixel, tokens)

    @staticmethod
    def _lookup(table, values):
        # Element-wise dict lookup. Without ``otypes``, np.vectorize infers the
        # output dtype from the first result; unknown keys yield None via dict.get.
        return np.vectorize(table.get)(values)

class MNISTPredictor:
    """Detect and classify handwritten digits in an image with a Hub-hosted Keras model.

    Args:
        model_name: Hugging Face Hub repository id that contains
            ``mnist_model.keras`` and ``mnist_tokenizer.pkl``.
    """

    def __init__(self, model_name):
        # Download the model and tokenizer files
        model_path = hf_hub_download(repo_id=model_name, filename="mnist_model.keras")
        tokenizer_path = hf_hub_download(repo_id=model_name, filename="mnist_tokenizer.pkl")

        # Load the model and tokenizer
        self.model = keras.models.load_model(model_path)
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load tokenizers from repositories you trust. Unpickling also needs
        # the tokenizer's class (presumably ImageTokenizer) to be importable.
        with open(tokenizer_path, 'rb') as tokenizer_file:
            self.tokenizer = pickle.load(tokenizer_file)

    def extract_features(self, image: Image.Image, min_area: float = 50) -> List[np.ndarray]:
        """Segment *image* into per-digit crops formatted like MNIST samples.

        Args:
            image: Input PIL image in any mode (RGB, RGBA, L, P, ...).
            min_area: Contours whose area is not above this value are ignored
                as noise. Adjust to the scale of the input image.

        Returns:
            One ``(28, 28, 1)`` float32 array with values in [0, 1] per detected
            digit, ordered left to right by bounding-box x coordinate. Empty
            when no contour passes the area filter.
        """
        # Normalise to 3-channel RGB so single-channel modes ("L", "P") also work
        # with COLOR_RGB2GRAY, which rejects 1-channel input.
        rgb = np.array(image.convert("RGB"))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

        # Apply Gaussian blur
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)

        # Adaptive threshold, inverted so strokes are white (255) on black,
        # matching MNIST's polarity.
        thresh = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)

        # Find contours
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Drop small contours and sort the rest left to right, so predictions
        # follow reading order rather than OpenCV's contour order.
        boxes = sorted(
            (cv2.boundingRect(contour) for contour in contours if cv2.contourArea(contour) > min_area),
            key=lambda box: box[0],
        )

        digit_images = []
        for x, y, w, h in boxes:
            roi = thresh[y:y+h, x:x+w]
            # NOTE(review): stretches the crop to 28x28 without preserving its
            # aspect ratio or adding MNIST-style padding.
            resized = cv2.resize(roi, (28, 28), interpolation=cv2.INTER_AREA)
            digit_images.append(resized.reshape((28, 28, 1)).astype('float32') / 255)

        return digit_images

    def predict(self, image: Image.Image) -> Optional[List[int]]:
        """Predict the digits in *image*, in left-to-right order.

        Returns:
            The predicted digit classes (an empty list when no digit is found),
            or ``None`` if feature extraction or inference raised (the error is
            printed).
        """
        try:
            digit_images = self.extract_features(image)
            if not digit_images:
                # np.array([]) would have shape (0,), not a (0, 28, 28, 1) batch,
                # so skip the model entirely.
                return []
            tokenized_images = [self.tokenizer.tokenize(img) for img in digit_images]
            predictions = self.model.predict(np.array(tokenized_images), verbose=0)
            return np.argmax(predictions, axis=1).tolist()
        except Exception as e:
            print(f"Error during prediction: {e}")
            return None

def download_image(url: str, timeout: float = 30.0) -> Optional[Image.Image]:
    """Download and decode an image from *url*.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds ``requests`` waits on the connection before giving up.
            Without a timeout, ``requests.get`` can block indefinitely.

    Returns:
        The decoded PIL image, or ``None`` if the download or decoding failed
        (the error is printed).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt or unsupported data is
        # reported here instead of surfacing later in the predictor.
        image.load()
        return image
    except Exception as e:
        print(f"Error downloading image: {e}")
        return None

def save_predictions_to_file(predictions: List[int], output_path: str) -> None:
    """Overwrite *output_path* with a single line listing *predictions*.

    The digits are comma-separated. Any error while writing is printed, not raised.
    """
    try:
        with open(output_path, 'w') as handle:
            joined = ', '.join(map(str, predictions))
            handle.write(f"Predicted digits are: {joined}\n")
    except Exception as err:
        print(f"Error saving predictions to file: {err}")

def main(image_url: str, model_name: str, output_path: str) -> None:
    """Download an image, predict the digits in it and save them to a file.

    Args:
        image_url: URL of the image to classify.
        model_name: Hugging Face Hub repository id passed to MNISTPredictor.
        output_path: Text file that receives the predictions.

    Errors are printed rather than raised.
    """
    try:
        # Fetch the image before loading the model, so a bad URL does not
        # trigger the model download.
        image = download_image(image_url)
        if image is None:
            print("An error occurred: Failed to download image")
            return

        print("Image downloaded successfully.")

        predictor = MNISTPredictor(model_name)

        # Predict digits
        digits = predictor.predict(image)
        if digits is None:
            print("Failed to predict digits.")
            return

        print(f"Predicted digits are: {digits}")

        # Save predictions to file.
        # NOTE(review): save_predictions_to_file prints its own errors and does
        # not raise, so the message below appears even if saving failed.
        save_predictions_to_file(digits, output_path)
        print(f"Predictions saved to {output_path}")
    except Exception as e:
        print(f"An error occurred: {e}")

if __name__ == "__main__":
    # Demo run: classify the digits in a sample image with the published model
    # and write the result next to the script.
    main(
        image_url="https://miro.medium.com/v2/resize:fit:720/format:webp/1*w7pBsjI3t3ZP-4Gdog-JdQ.png",
        model_name="0xnu/mnist-ocr",
        output_path="predictions.txt",
    )
```

### Copyright

(c) 2024 [Finbarrs Oketunji](https://finbarrs.eu). All Rights Reserved.