File size: 5,630 Bytes
9e430b7 980f6f8 9e430b7 058f4c8 980f6f8 530b58b e99dff6 f5b9188 9facfab bcfe26b 730ff55 a45bda7 730ff55 a45bda7 730ff55 a45bda7 730ff55 a45bda7 730ff55 a45bda7 730ff55 a45bda7 730ff55 a45bda7 730ff55 a45bda7 730ff55 a45bda7 730ff55 a45bda7 730ff55 a45bda7 730ff55 f5b9188 7d5ba4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
---
license: apache-2.0
tags:
- vision
- image-classification
datasets:
- dmitva/the-mnist-database
inference: true
pipeline_tag: image-classification
widget:
- text: "Enter image URL"
example: https://miro.medium.com/v2/resize:fit:720/format:webp/1*w7pBsjI3t3ZP-4Gdog-JdQ.png
---
The MNIST OCR (Optical Character Recognition) model is a deep learning model trained to recognise and classify handwritten digits from 0 to 9. It is trained on the MNIST dataset, whose training split consists of 60,000 small, square, 28×28-pixel grayscale images of single handwritten digits. This makes it highly accurate at recognising isolated handwritten digits written in a style similar to those in the training set.

### Install Packages
```sh
pip install numpy opencv-python requests pillow huggingface_hub tensorflow
```
### Usage
```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import numpy as np
import cv2
import requests
from PIL import Image
from io import BytesIO
from typing import List, Optional
from huggingface_hub import hf_hub_download
import tensorflow as tf
import pickle
class ImageTokenizer:
    """Map raw pixel values to dense integer tokens and back.

    ``fit`` gathers every distinct pixel value seen in the given images and
    numbers them 0..K-1 in ascending order; ``tokenize`` / ``detokenize``
    apply that mapping (or its inverse) element-wise.
    """

    def __init__(self):
        # Keep these attribute names stable: the usage script restores
        # instances with pickle, which repopulates attributes by name.
        self.unique_pixels = set()
        self.pixel_to_token = {}
        self.token_to_pixel = {}

    def fit(self, images):
        """Add each image's distinct values to the vocabulary and rebuild both maps.

        Values from earlier ``fit`` calls are retained, so existing token ids
        can shift when new values are added.
        """
        for img in images:
            self.unique_pixels.update(np.unique(img))
        ordered = sorted(self.unique_pixels)
        self.pixel_to_token = dict(zip(ordered, range(len(ordered))))
        self.token_to_pixel = {token: value for value, token in self.pixel_to_token.items()}

    def tokenize(self, images):
        """Replace every pixel value in ``images`` with its token id.

        Lookups go through ``dict.get``, so values missing from the fitted
        vocabulary become ``None``.
        """
        lookup = np.vectorize(self.pixel_to_token.get)
        return lookup(images)

    def detokenize(self, tokens):
        """Replace every token id in ``tokens`` with its original pixel value."""
        reverse_lookup = np.vectorize(self.token_to_pixel.get)
        return reverse_lookup(tokens)
class MNISTPredictor:
    """Locate handwritten digits in an image and classify them with a Keras model.

    The model (``mnist_model.keras``) and a pickled pixel tokenizer
    (``mnist_tokenizer.pkl``) are downloaded from the Hugging Face Hub
    repository ``model_name``.
    """

    def __init__(self, model_name):
        """Download and load the model and tokenizer.

        Args:
            model_name: Hugging Face Hub repo id, e.g. ``"0xnu/mnist-ocr"``.
        """
        model_path = hf_hub_download(repo_id=model_name, filename="mnist_model.keras")
        tokenizer_path = hf_hub_download(repo_id=model_name, filename="mnist_tokenizer.pkl")
        self.model = keras.models.load_model(model_path)
        # SECURITY: unpickling can execute arbitrary code -- only load a
        # tokenizer file from a repository you trust.
        # NOTE(review): the pickle presumably references ImageTokenizer by the
        # module name it was saved from, so that class must be defined where
        # this runs -- TODO confirm.
        with open(tokenizer_path, 'rb') as tokenizer_file:
            self.tokenizer = pickle.load(tokenizer_file)

    def extract_features(self, image: Image.Image, *, min_area: float = 50) -> List[np.ndarray]:
        """Segment ``image`` into one normalised 28x28x1 crop per digit.

        Digits are found as external contours of an inverted, adaptively
        thresholded grayscale image, so dark ink on a light background
        becomes foreground (255).

        Args:
            image: Any PIL image. It is converted to RGB first, so grayscale,
                palette and RGBA images are accepted.
            min_area: Contours with area at or below this value are treated
                as noise and skipped.

        Returns:
            float32 arrays of shape (28, 28, 1) with values in [0, 1], ordered
            left to right by each digit's bounding-box left edge.
        """
        # cv2.COLOR_RGB2GRAY needs a 3- or 4-channel array; PIL 'L' or 'P'
        # images would give a 2-D array and make cvtColor raise.
        rgb = np.array(image.convert("RGB"))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        # Smooth noise before thresholding.
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        thresh = cv2.adaptiveThreshold(
            blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
        )
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Drop small specks, then sort by x: findContours gives no ordering
        # guarantee, and callers read the result as a digit sequence.
        boxes = sorted(
            (cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > min_area),
            key=lambda box: box[0],
        )

        digit_images = []
        for x, y, w, h in boxes:
            roi = thresh[y:y + h, x:x + w]
            # NOTE(review): the crop is stretched to 28x28 without preserving
            # aspect ratio; kept as-is to match the existing preprocessing.
            resized = cv2.resize(roi, (28, 28), interpolation=cv2.INTER_AREA)
            digit_images.append(resized.reshape((28, 28, 1)).astype('float32') / 255)
        return digit_images

    def predict(self, image: Image.Image) -> Optional[List[int]]:
        """Predict the digits in ``image``, left to right.

        Returns:
            The argmax class index for each detected digit; ``[]`` when no
            digit-sized contour is found; ``None`` if an exception occurred
            (the error is printed).
        """
        try:
            digit_images = self.extract_features(image)
            if not digit_images:
                # Nothing to classify; an empty batch would only make the
                # model raise and be reported as a failure.
                return []
            tokenized_images = [self.tokenizer.tokenize(img) for img in digit_images]
            predictions = self.model.predict(np.array(tokenized_images), verbose=0)
            return np.argmax(predictions, axis=1).tolist()
        except Exception as e:
            print(f"Error during prediction: {e}")
            return None
def download_image(url: str, timeout: float = 30.0) -> Optional[Image.Image]:
    """Download ``url`` and open the response body as a PIL image.

    Args:
        url: Address of the image.
        timeout: Seconds to wait for the server, passed to ``requests.get``.
            Without a timeout the request can block indefinitely.

    Returns:
        The decoded image, or ``None`` if the download or decoding failed
        (the error is printed).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open only reads the header; decode now so a corrupt payload
        # is reported here rather than later during prediction.
        image.load()
        return image
    except Exception as e:
        print(f"Error downloading image: {e}")
        return None
def save_predictions_to_file(predictions: List[int], output_path: str) -> None:
    """Write ``predictions`` to ``output_path`` as one comma-separated line.

    Any ``Exception`` raised while opening or writing the file is caught and
    printed instead of being propagated.
    """
    try:
        with open(output_path, 'w') as handle:
            joined = ", ".join(str(digit) for digit in predictions)
            handle.write(f"Predicted digits are: {joined}\n")
    except Exception as err:
        print(f"Error saving predictions to file: {err}")
def main(image_url: str, model_name: str, output_path: str) -> None:
    """Run the pipeline: load the model, fetch the image, predict, save.

    Any ``Exception`` raised along the way is caught and printed.
    """
    try:
        predictor = MNISTPredictor(model_name)

        image = download_image(image_url)
        if image is None:
            # Routed through the handler below so it is reported like any other failure.
            raise Exception("Failed to download image")
        print("Image downloaded successfully.")

        digits = predictor.predict(image)
        if digits is None:
            print("Failed to predict digits.")
            return

        print(f"Predicted digits are: {digits}")
        save_predictions_to_file(digits, output_path)
        print(f"Predictions saved to {output_path}")
    except Exception as e:
        print(f"An error occurred: {e}")
if __name__ == "__main__":
    # Demo run: classify the digits in a sample image with the published
    # model and write the result to a local text file.
    image_url = (
        "https://miro.medium.com/v2/resize:fit:720/format:webp/"
        "1*w7pBsjI3t3ZP-4Gdog-JdQ.png"
    )
    model_name = "0xnu/mnist-ocr"
    output_path = "predictions.txt"
    main(image_url=image_url, model_name=model_name, output_path=output_path)
```
### Copyright
(c) 2024 [Finbarrs Oketunji](https://finbarrs.eu). All Rights Reserved.
|