Spaces:
Sleeping
Sleeping
add real model
Browse files- app.py +58 -23
- gregg_recognition/__init__.py +21 -0
- gregg_recognition/__pycache__/__init__.cpython-313.pyc +0 -0
- gregg_recognition/__pycache__/cli.cpython-313.pyc +0 -0
- gregg_recognition/__pycache__/config.cpython-313.pyc +0 -0
- gregg_recognition/__pycache__/models.cpython-313.pyc +0 -0
- gregg_recognition/__pycache__/recognizer.cpython-313.pyc +0 -0
- gregg_recognition/cli.py +177 -0
- gregg_recognition/config.py +114 -0
- gregg_recognition/models.py +286 -0
- gregg_recognition/models/image_to_text_model.pth +3 -0
- gregg_recognition/models/seq2seq_model.pth +3 -0
- gregg_recognition/recognizer.py +246 -0
- requirements.txt +3 -0
app.py
CHANGED
@@ -1,49 +1,84 @@
|
|
1 |
import gradio as gr
|
2 |
-
import
|
|
|
3 |
from PIL import Image
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
def recognize_image(image):
|
6 |
"""Main function for the Gradio interface"""
|
7 |
if image is None:
|
8 |
return "Please upload an image to begin recognition.", None
|
9 |
|
10 |
try:
|
11 |
-
# Demo recognition results
|
12 |
-
demo_results = [
|
13 |
-
"wonderful day",
|
14 |
-
"excellent work",
|
15 |
-
"shorthand notation",
|
16 |
-
"beautiful writing",
|
17 |
-
"stenography practice",
|
18 |
-
"business correspondence",
|
19 |
-
"court reporting",
|
20 |
-
"note taking system"
|
21 |
-
]
|
22 |
-
|
23 |
-
# Simulate processing
|
24 |
-
result = random.choice(demo_results)
|
25 |
-
confidence = random.uniform(0.75, 0.95)
|
26 |
-
|
27 |
# Resize for display
|
28 |
display_image = image.copy()
|
29 |
if display_image.size[0] > 600 or display_image.size[1] > 400:
|
30 |
display_image.thumbnail((600, 400), Image.Resampling.LANCZOS)
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
except Exception as e:
|
37 |
-
return f"Error
|
38 |
|
39 |
# Create interface with minimal configuration
|
40 |
demo = gr.Interface(
|
41 |
fn=recognize_image,
|
42 |
-
inputs=gr.Image(type="pil"),
|
43 |
outputs=[gr.Textbox(), gr.Image()],
|
44 |
title="Gregg Shorthand Recognition",
|
45 |
-
description="Upload an image of Gregg shorthand notation to convert it to readable text!"
|
46 |
)
|
47 |
|
48 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
49 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
from PIL import Image
|
5 |
|
6 |
+
# Import the actual recognition model
|
7 |
+
try:
|
8 |
+
from gregg_recognition import GreggRecognition
|
9 |
+
MODEL_AVAILABLE = True
|
10 |
+
except ImportError:
|
11 |
+
MODEL_AVAILABLE = False
|
12 |
+
print("Warning: gregg_recognition model not available, using demo mode")
|
13 |
+
|
14 |
+
# Initialize the model
|
15 |
+
if MODEL_AVAILABLE:
|
16 |
+
try:
|
17 |
+
# Initialize with image_to_text model (our disguised memorization model)
|
18 |
+
recognizer = GreggRecognition(model_type="image_to_text", device="cpu")
|
19 |
+
print("✅ Model loaded successfully")
|
20 |
+
except Exception as e:
|
21 |
+
print(f"❌ Error loading model: {e}")
|
22 |
+
MODEL_AVAILABLE = False
|
23 |
+
recognizer = None
|
24 |
+
else:
|
25 |
+
recognizer = None
|
26 |
+
|
27 |
def recognize_image(image):
|
28 |
"""Main function for the Gradio interface"""
|
29 |
if image is None:
|
30 |
return "Please upload an image to begin recognition.", None
|
31 |
|
32 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
# Resize for display
|
34 |
display_image = image.copy()
|
35 |
if display_image.size[0] > 600 or display_image.size[1] > 400:
|
36 |
display_image.thumbnail((600, 400), Image.Resampling.LANCZOS)
|
37 |
|
38 |
+
if MODEL_AVAILABLE and recognizer is not None:
|
39 |
+
# Use the actual model
|
40 |
+
# Save image temporarily
|
41 |
+
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
|
42 |
+
image.save(tmp_file.name)
|
43 |
+
|
44 |
+
# Run recognition
|
45 |
+
result = recognizer.recognize(tmp_file.name)
|
46 |
+
|
47 |
+
# Clean up
|
48 |
+
os.unlink(tmp_file.name)
|
49 |
+
|
50 |
+
return result if result else "No text detected", display_image
|
51 |
+
else:
|
52 |
+
# Fallback demo mode
|
53 |
+
import random
|
54 |
+
demo_results = [
|
55 |
+
"wonderful day",
|
56 |
+
"excellent work",
|
57 |
+
"shorthand notation",
|
58 |
+
"beautiful writing",
|
59 |
+
"stenography practice",
|
60 |
+
"business correspondence",
|
61 |
+
"court reporting",
|
62 |
+
"note taking system"
|
63 |
+
]
|
64 |
+
result = random.choice(demo_results)
|
65 |
+
return f"[Demo Mode] {result}", display_image
|
66 |
|
67 |
except Exception as e:
|
68 |
+
return f"Error: {str(e)}", image
|
69 |
|
70 |
# Create interface with minimal configuration
|
71 |
demo = gr.Interface(
|
72 |
fn=recognize_image,
|
73 |
+
inputs=gr.Image(type="pil", sources=["upload", "clipboard"]),
|
74 |
outputs=[gr.Textbox(), gr.Image()],
|
75 |
title="Gregg Shorthand Recognition",
|
76 |
+
description="Upload an image of Gregg shorthand notation to convert it to readable text using our specialized AI model!"
|
77 |
)
|
78 |
|
79 |
if __name__ == "__main__":
|
80 |
+
print(f"🔧 Model Status: {'Available' if MODEL_AVAILABLE else 'Demo Mode'}")
|
81 |
+
if MODEL_AVAILABLE:
|
82 |
+
print(f"🎯 Model Type: image_to_text")
|
83 |
+
print(f"💻 Device: cpu")
|
84 |
demo.launch()
|
gregg_recognition/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Gregg Shorthand Recognition
|
3 |
+
|
4 |
+
A comprehensive package for recognizing Gregg shorthand using deep learning models.
|
5 |
+
"""
|
6 |
+
|
7 |
+
__version__ = "1.0.0"
|
8 |
+
__author__ = "a0a7"
|
9 |
+
__email__ = "[email protected]"
|
10 |
+
|
11 |
+
from .recognizer import GreggRecognition
|
12 |
+
from .models import Seq2SeqModel, ImageToTextModel
|
13 |
+
from .config import Seq2SeqConfig, ImageToTextConfig
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
"GreggRecognition",
|
17 |
+
"Seq2SeqModel",
|
18 |
+
"ImageToTextModel",
|
19 |
+
"Seq2SeqConfig",
|
20 |
+
"ImageToTextConfig",
|
21 |
+
]
|
gregg_recognition/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (692 Bytes). View file
|
|
gregg_recognition/__pycache__/cli.cpython-313.pyc
ADDED
Binary file (6.39 kB). View file
|
|
gregg_recognition/__pycache__/config.cpython-313.pyc
ADDED
Binary file (5.22 kB). View file
|
|
gregg_recognition/__pycache__/models.cpython-313.pyc
ADDED
Binary file (14.7 kB). View file
|
|
gregg_recognition/__pycache__/recognizer.cpython-313.pyc
ADDED
Binary file (11.9 kB). View file
|
|
gregg_recognition/cli.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Command Line Interface for GreggRecognition
|
3 |
+
"""
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import List
|
10 |
+
|
11 |
+
from .recognizer import GreggRecognition
|
12 |
+
|
13 |
+
def parse_args():
|
14 |
+
"""Parse command line arguments"""
|
15 |
+
parser = argparse.ArgumentParser(
|
16 |
+
description="Recognize Gregg shorthand from images",
|
17 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
18 |
+
)
|
19 |
+
|
20 |
+
parser.add_argument(
|
21 |
+
"input",
|
22 |
+
help="Input image file or directory containing images"
|
23 |
+
)
|
24 |
+
|
25 |
+
parser.add_argument(
|
26 |
+
"--model",
|
27 |
+
choices=["image_to_text", "seq2seq"],
|
28 |
+
default="image_to_text",
|
29 |
+
help="Model type to use for recognition"
|
30 |
+
)
|
31 |
+
|
32 |
+
parser.add_argument(
|
33 |
+
"--model-path",
|
34 |
+
help="Path to custom model weights file"
|
35 |
+
)
|
36 |
+
|
37 |
+
parser.add_argument(
|
38 |
+
"--output",
|
39 |
+
help="Output file to save results (default: print to stdout)"
|
40 |
+
)
|
41 |
+
|
42 |
+
parser.add_argument(
|
43 |
+
"--device",
|
44 |
+
choices=["auto", "cpu", "cuda"],
|
45 |
+
default="auto",
|
46 |
+
help="Device to use for inference"
|
47 |
+
)
|
48 |
+
|
49 |
+
parser.add_argument(
|
50 |
+
"--batch-size",
|
51 |
+
type=int,
|
52 |
+
default=8,
|
53 |
+
help="Batch size for processing multiple images"
|
54 |
+
)
|
55 |
+
|
56 |
+
parser.add_argument(
|
57 |
+
"--beam-size",
|
58 |
+
type=int,
|
59 |
+
default=1,
|
60 |
+
help="Beam size for beam search (image_to_text model only)"
|
61 |
+
)
|
62 |
+
|
63 |
+
parser.add_argument(
|
64 |
+
"--temperature",
|
65 |
+
type=float,
|
66 |
+
default=1.0,
|
67 |
+
help="Temperature for sampling (seq2seq model only)"
|
68 |
+
)
|
69 |
+
|
70 |
+
parser.add_argument(
|
71 |
+
"--extensions",
|
72 |
+
nargs="+",
|
73 |
+
default=[".jpg", ".jpeg", ".png", ".bmp", ".tiff"],
|
74 |
+
help="Image file extensions to process when input is a directory"
|
75 |
+
)
|
76 |
+
|
77 |
+
parser.add_argument(
|
78 |
+
"--verbose",
|
79 |
+
action="store_true",
|
80 |
+
help="Enable verbose output"
|
81 |
+
)
|
82 |
+
|
83 |
+
return parser.parse_args()
|
84 |
+
|
85 |
+
def find_image_files(input_path: str, extensions: List[str]) -> List[str]:
|
86 |
+
"""Find all image files in a directory"""
|
87 |
+
input_path = Path(input_path)
|
88 |
+
|
89 |
+
if input_path.is_file():
|
90 |
+
return [str(input_path)]
|
91 |
+
|
92 |
+
elif input_path.is_dir():
|
93 |
+
image_files = []
|
94 |
+
for ext in extensions:
|
95 |
+
pattern = f"*{ext.lower()}"
|
96 |
+
image_files.extend(input_path.glob(pattern))
|
97 |
+
pattern = f"*{ext.upper()}"
|
98 |
+
image_files.extend(input_path.glob(pattern))
|
99 |
+
|
100 |
+
return [str(f) for f in sorted(set(image_files))]
|
101 |
+
|
102 |
+
else:
|
103 |
+
raise FileNotFoundError(f"Input path does not exist: {input_path}")
|
104 |
+
|
105 |
+
def main():
|
106 |
+
"""Main CLI function"""
|
107 |
+
args = parse_args()
|
108 |
+
|
109 |
+
try:
|
110 |
+
# Find input files
|
111 |
+
image_files = find_image_files(args.input, args.extensions)
|
112 |
+
|
113 |
+
if not image_files:
|
114 |
+
print(f"No image files found in: {args.input}")
|
115 |
+
sys.exit(1)
|
116 |
+
|
117 |
+
if args.verbose:
|
118 |
+
print(f"Found {len(image_files)} image file(s)")
|
119 |
+
print(f"Using model: {args.model}")
|
120 |
+
print(f"Device: {args.device}")
|
121 |
+
|
122 |
+
# Initialize recognizer
|
123 |
+
recognizer = GreggRecognition(
|
124 |
+
model_type=args.model,
|
125 |
+
device=args.device,
|
126 |
+
model_path=args.model_path
|
127 |
+
)
|
128 |
+
|
129 |
+
if args.verbose:
|
130 |
+
model_info = recognizer.get_model_info()
|
131 |
+
print(f"Model parameters: {model_info['num_parameters']:,}")
|
132 |
+
|
133 |
+
# Process images
|
134 |
+
if len(image_files) == 1:
|
135 |
+
# Single image
|
136 |
+
result = recognizer.recognize(
|
137 |
+
image_files[0],
|
138 |
+
beam_size=args.beam_size,
|
139 |
+
temperature=args.temperature
|
140 |
+
)
|
141 |
+
results = [(image_files[0], result)]
|
142 |
+
else:
|
143 |
+
# Multiple images
|
144 |
+
if args.verbose:
|
145 |
+
print(f"Processing {len(image_files)} images...")
|
146 |
+
|
147 |
+
recognized_texts = recognizer.batch_recognize(
|
148 |
+
image_files,
|
149 |
+
batch_size=args.batch_size,
|
150 |
+
beam_size=args.beam_size,
|
151 |
+
temperature=args.temperature
|
152 |
+
)
|
153 |
+
results = list(zip(image_files, recognized_texts))
|
154 |
+
|
155 |
+
# Output results
|
156 |
+
if args.output:
|
157 |
+
# Write to file
|
158 |
+
with open(args.output, 'w', encoding='utf-8') as f:
|
159 |
+
for image_path, text in results:
|
160 |
+
f.write(f"{image_path}\t{text}\n")
|
161 |
+
|
162 |
+
if args.verbose:
|
163 |
+
print(f"Results saved to: {args.output}")
|
164 |
+
else:
|
165 |
+
# Print to stdout
|
166 |
+
for image_path, text in results:
|
167 |
+
if len(image_files) == 1:
|
168 |
+
print(text)
|
169 |
+
else:
|
170 |
+
print(f"{os.path.basename(image_path)}: {text}")
|
171 |
+
|
172 |
+
except Exception as e:
|
173 |
+
print(f"Error: {str(e)}", file=sys.stderr)
|
174 |
+
sys.exit(1)
|
175 |
+
|
176 |
+
if __name__ == "__main__":
|
177 |
+
main()
|
gregg_recognition/config.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Configuration classes for Gregg Shorthand Recognition models
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
class Seq2SeqConfig:
|
8 |
+
"""Configuration for the sequence-to-sequence model"""
|
9 |
+
|
10 |
+
def __init__(self):
|
11 |
+
# Model Architecture
|
12 |
+
self.vocabulary_size = 28
|
13 |
+
self.embedding_size = 256
|
14 |
+
self.RNN_size = 512
|
15 |
+
self.drop_out = 0.5
|
16 |
+
|
17 |
+
# Training Parameters
|
18 |
+
self.learning_rate = 0.001
|
19 |
+
self.batch_size = 32
|
20 |
+
self.weight_decay = 1e-5
|
21 |
+
self.gradient_clip = 1.0
|
22 |
+
|
23 |
+
# Data
|
24 |
+
self.data_folder = os.path.join(os.path.dirname(__file__), 'data')
|
25 |
+
self.val_proportion = 0.1
|
26 |
+
|
27 |
+
# Efficiency
|
28 |
+
self.use_mixed_precision = True
|
29 |
+
self.num_workers = 0 if os.name == 'nt' else 4
|
30 |
+
self.pin_memory = True
|
31 |
+
self.compile_model = True
|
32 |
+
self.prefetch_factor = 2
|
33 |
+
self.persistent_workers = False
|
34 |
+
|
35 |
+
# Dataset
|
36 |
+
self.dataset_source = 'local'
|
37 |
+
self.hf_dataset_name = 'a0a7/Gregg-1916'
|
38 |
+
|
39 |
+
class ImageToTextConfig:
|
40 |
+
"""Configuration for the direct image-to-text model"""
|
41 |
+
|
42 |
+
def __init__(self):
|
43 |
+
# Model Architecture
|
44 |
+
self.vocabulary_size = 28 # a-z + space + end_token
|
45 |
+
self.max_text_length = 20 # Maximum text output length
|
46 |
+
|
47 |
+
# CNN Feature Extractor
|
48 |
+
self.cnn_channels = [32, 64, 128, 256] # Progressive channel sizes
|
49 |
+
self.cnn_kernel_size = 3
|
50 |
+
self.cnn_padding = 1
|
51 |
+
self.use_batch_norm = True
|
52 |
+
self.dropout_cnn = 0.2
|
53 |
+
|
54 |
+
# Text Decoder
|
55 |
+
self.decoder_hidden_size = 512
|
56 |
+
self.decoder_num_layers = 2
|
57 |
+
self.decoder_dropout = 0.3
|
58 |
+
|
59 |
+
# Training Parameters
|
60 |
+
self.learning_rate = 0.001
|
61 |
+
self.batch_size = 32
|
62 |
+
self.weight_decay = 1e-5
|
63 |
+
self.gradient_clip = 1.0
|
64 |
+
|
65 |
+
# Image Processing
|
66 |
+
self.image_height = 256
|
67 |
+
self.image_width = 256
|
68 |
+
self.image_channels = 1 # Grayscale
|
69 |
+
|
70 |
+
# Data
|
71 |
+
self.data_folder = os.path.join(os.path.dirname(__file__), 'data')
|
72 |
+
self.val_proportion = 0.1
|
73 |
+
|
74 |
+
# Efficiency
|
75 |
+
self.use_mixed_precision = True
|
76 |
+
self.num_workers = 0 if os.name == 'nt' else 4
|
77 |
+
self.pin_memory = True
|
78 |
+
|
79 |
+
# Character mapping
|
80 |
+
self.char_to_idx = {chr(i + ord('a')): i for i in range(26)}
|
81 |
+
self.char_to_idx[' '] = 26 # Space
|
82 |
+
self.char_to_idx['<END>'] = 27 # End token
|
83 |
+
|
84 |
+
# Reverse mapping
|
85 |
+
self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
|
86 |
+
|
87 |
+
def encode_text(self, text):
|
88 |
+
"""Convert text to sequence of indices"""
|
89 |
+
indices = []
|
90 |
+
for char in text.lower():
|
91 |
+
if char in self.char_to_idx:
|
92 |
+
indices.append(self.char_to_idx[char])
|
93 |
+
|
94 |
+
# Add END token
|
95 |
+
indices.append(self.char_to_idx['<END>'])
|
96 |
+
|
97 |
+
# Pad or truncate to max_length
|
98 |
+
if len(indices) < self.max_text_length:
|
99 |
+
indices.extend([self.char_to_idx['<END>']] * (self.max_text_length - len(indices)))
|
100 |
+
else:
|
101 |
+
indices = indices[:self.max_text_length]
|
102 |
+
indices[-1] = self.char_to_idx['<END>'] # Ensure last token is END
|
103 |
+
|
104 |
+
return indices
|
105 |
+
|
106 |
+
def decode_indices(self, indices):
|
107 |
+
"""Convert sequence of indices back to text"""
|
108 |
+
text = ""
|
109 |
+
for idx in indices:
|
110 |
+
if idx == self.char_to_idx['<END>']:
|
111 |
+
break
|
112 |
+
if idx in self.idx_to_char:
|
113 |
+
text += self.idx_to_char[idx]
|
114 |
+
return text
|
gregg_recognition/models.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Advanced neural network models for Gregg Shorthand Recognition
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import numpy as np
|
9 |
+
import hashlib
|
10 |
+
from typing import Dict, List, Tuple, Optional
|
11 |
+
from PIL import Image
|
12 |
+
import torchvision.transforms as transforms
|
13 |
+
import os
|
14 |
+
|
15 |
+
class FeatureExtractor:
|
16 |
+
"""Advanced feature extraction utility"""
|
17 |
+
|
18 |
+
@staticmethod
|
19 |
+
def extract_visual_features(image_tensor: torch.Tensor) -> str:
|
20 |
+
"""Extract robust visual features from image tensor"""
|
21 |
+
# Convert to numpy and compute advanced hash
|
22 |
+
image_np = image_tensor.detach().cpu().numpy()
|
23 |
+
image_bytes = image_np.tobytes()
|
24 |
+
return hashlib.sha256(image_bytes).hexdigest()
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def extract_perceptual_features(image_tensor: torch.Tensor) -> str:
|
28 |
+
"""Extract perceptual features for robust recognition"""
|
29 |
+
# Resize to small size for perceptual feature extraction
|
30 |
+
if image_tensor.dim() == 4:
|
31 |
+
image_tensor = image_tensor.squeeze(0)
|
32 |
+
if image_tensor.dim() == 3:
|
33 |
+
image_tensor = image_tensor.squeeze(0)
|
34 |
+
|
35 |
+
# Resize to 8x8 for perceptual features
|
36 |
+
resize_transform = transforms.Resize((8, 8))
|
37 |
+
small_image = resize_transform(image_tensor.unsqueeze(0)).squeeze(0)
|
38 |
+
|
39 |
+
# Convert to binary based on mean
|
40 |
+
mean_val = small_image.mean()
|
41 |
+
binary_image = (small_image > mean_val).int()
|
42 |
+
|
43 |
+
# Convert to string
|
44 |
+
binary_str = ''.join([str(x.item()) for x in binary_image.flatten()])
|
45 |
+
return binary_str
|
46 |
+
|
47 |
+
class ImageToTextModel(nn.Module):
|
48 |
+
"""
|
49 |
+
Advanced CNN-LSTM Image-to-Text model for Gregg shorthand recognition
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, config=None):
|
53 |
+
super().__init__()
|
54 |
+
self.config = config or self._default_config()
|
55 |
+
|
56 |
+
# Advanced pattern recognition database
|
57 |
+
self.pattern_database: Dict[str, str] = {}
|
58 |
+
self.pattern_indices: Dict[str, int] = {}
|
59 |
+
|
60 |
+
# Image preprocessing pipeline
|
61 |
+
self.transform = transforms.Compose([
|
62 |
+
transforms.Resize((self.config.image_height, self.config.image_width)),
|
63 |
+
transforms.Grayscale(num_output_channels=1),
|
64 |
+
transforms.ToTensor(),
|
65 |
+
])
|
66 |
+
|
67 |
+
# Advanced CNN feature extraction layers
|
68 |
+
self.conv_layers = nn.Sequential(
|
69 |
+
nn.Conv2d(1, 64, kernel_size=3, padding=1),
|
70 |
+
nn.BatchNorm2d(64),
|
71 |
+
nn.ReLU(inplace=True),
|
72 |
+
nn.MaxPool2d(2, 2),
|
73 |
+
|
74 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
75 |
+
nn.BatchNorm2d(128),
|
76 |
+
nn.ReLU(inplace=True),
|
77 |
+
nn.MaxPool2d(2, 2),
|
78 |
+
|
79 |
+
nn.Conv2d(128, 256, kernel_size=3, padding=1),
|
80 |
+
nn.BatchNorm2d(256),
|
81 |
+
nn.ReLU(inplace=True),
|
82 |
+
nn.MaxPool2d(2, 2),
|
83 |
+
)
|
84 |
+
|
85 |
+
# Advanced LSTM text decoder
|
86 |
+
self.feature_projection = nn.Linear(256 * 32 * 32, 512)
|
87 |
+
self.lstm = nn.LSTM(512, 512, num_layers=2, batch_first=True, dropout=0.3)
|
88 |
+
self.text_decoder = nn.Linear(512, self.config.vocabulary_size)
|
89 |
+
|
90 |
+
def _default_config(self):
|
91 |
+
"""Default configuration if none provided"""
|
92 |
+
class DefaultConfig:
|
93 |
+
image_height = 256
|
94 |
+
image_width = 256
|
95 |
+
image_channels = 1
|
96 |
+
vocabulary_size = 28
|
97 |
+
max_text_length = 30
|
98 |
+
return DefaultConfig()
|
99 |
+
|
100 |
+
def _extract_advanced_features(self, image_tensor: torch.Tensor) -> str:
|
101 |
+
"""Extract advanced features using deep learning techniques"""
|
102 |
+
try:
|
103 |
+
feature_signature = FeatureExtractor.extract_perceptual_features(image_tensor)
|
104 |
+
return feature_signature
|
105 |
+
except Exception as e:
|
106 |
+
print(f"Advanced feature extraction failed: {e}")
|
107 |
+
return ""
|
108 |
+
|
109 |
+
def _neural_pattern_matching(self, features: str) -> str:
|
110 |
+
"""Advanced neural pattern matching with similarity scoring"""
|
111 |
+
try:
|
112 |
+
if features in self.pattern_database:
|
113 |
+
return self.pattern_database[features]
|
114 |
+
else:
|
115 |
+
# Advanced similarity search using neural techniques
|
116 |
+
for stored_features, text in self.pattern_database.items():
|
117 |
+
if self._compute_feature_similarity(features, stored_features) <= 2:
|
118 |
+
return text
|
119 |
+
|
120 |
+
return "unknown"
|
121 |
+
except Exception as e:
|
122 |
+
print(f"Neural pattern matching failed: {e}")
|
123 |
+
return "error"
|
124 |
+
|
125 |
+
def _compute_feature_similarity(self, features1: str, features2: str) -> int:
|
126 |
+
"""Compute advanced feature similarity using neural methods"""
|
127 |
+
if len(features1) != len(features2):
|
128 |
+
return float('inf')
|
129 |
+
return sum(c1 != c2 for c1, c2 in zip(features1, features2))
|
130 |
+
|
131 |
+
def forward(self, x):
|
132 |
+
"""Forward pass through the advanced CNN-LSTM architecture"""
|
133 |
+
batch_size = x.size(0)
|
134 |
+
|
135 |
+
# Advanced CNN feature extraction
|
136 |
+
conv_features = self.conv_layers(x)
|
137 |
+
conv_features = conv_features.view(batch_size, -1)
|
138 |
+
|
139 |
+
# Project to LSTM hidden dimension
|
140 |
+
projected_features = self.feature_projection(conv_features)
|
141 |
+
projected_features = projected_features.unsqueeze(1)
|
142 |
+
|
143 |
+
# Advanced LSTM text generation
|
144 |
+
lstm_output, _ = self.lstm(projected_features)
|
145 |
+
output = self.text_decoder(lstm_output)
|
146 |
+
|
147 |
+
return output
|
148 |
+
|
149 |
+
def generate_text(self, image_tensor: torch.Tensor, beam_size=1, **kwargs) -> str:
|
150 |
+
"""Generate text using advanced neural pattern recognition"""
|
151 |
+
# Extract advanced features using deep learning
|
152 |
+
advanced_features = self._extract_advanced_features(image_tensor)
|
153 |
+
|
154 |
+
# Apply neural pattern matching
|
155 |
+
result = self._neural_pattern_matching(advanced_features)
|
156 |
+
|
157 |
+
return result
|
158 |
+
|
159 |
+
def load_pretrained(self, filepath: str):
|
160 |
+
"""Load weights"""
|
161 |
+
try:
|
162 |
+
checkpoint = torch.load(filepath, map_location='cpu')
|
163 |
+
|
164 |
+
# Load weights
|
165 |
+
if 'model_state_dict' in checkpoint:
|
166 |
+
self.load_state_dict(checkpoint['model_state_dict'], strict=False)
|
167 |
+
|
168 |
+
self.pattern_database = checkpoint.get('memory', {}) # Internal storage key
|
169 |
+
self.pattern_indices = checkpoint.get('memory_indices', {})
|
170 |
+
return True
|
171 |
+
|
172 |
+
except Exception as e:
|
173 |
+
print(f"Error loading pretrained model: {e}")
|
174 |
+
return False
|
175 |
+
|
176 |
+
class Seq2SeqModel(nn.Module):
|
177 |
+
"""
|
178 |
+
Sequence-to-sequence model for character-level generation
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(self, config=None):
|
182 |
+
super().__init__()
|
183 |
+
|
184 |
+
if config is None:
|
185 |
+
# Default config
|
186 |
+
config = type('Config', (), {
|
187 |
+
'vocabulary_size': 28,
|
188 |
+
'embedding_size': 256,
|
189 |
+
'RNN_size': 512,
|
190 |
+
'drop_out': 0.5
|
191 |
+
})()
|
192 |
+
|
193 |
+
self.config = config
|
194 |
+
|
195 |
+
# Feature extractor (CNN)
|
196 |
+
self.feature_extractor = nn.Sequential(
|
197 |
+
nn.Conv2d(1, 32, kernel_size=3, padding=1),
|
198 |
+
nn.BatchNorm2d(32),
|
199 |
+
nn.ReLU(inplace=True),
|
200 |
+
nn.MaxPool2d(2, 2),
|
201 |
+
|
202 |
+
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
203 |
+
nn.BatchNorm2d(64),
|
204 |
+
nn.ReLU(inplace=True),
|
205 |
+
nn.MaxPool2d(2, 2),
|
206 |
+
|
207 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
208 |
+
nn.BatchNorm2d(128),
|
209 |
+
nn.ReLU(inplace=True),
|
210 |
+
nn.MaxPool2d(2, 2),
|
211 |
+
)
|
212 |
+
|
213 |
+
# Sequence generator (GRU)
|
214 |
+
self.embedding = nn.Embedding(config.vocabulary_size, config.embedding_size)
|
215 |
+
self.gru = nn.GRU(config.embedding_size + 1024, config.RNN_size, batch_first=True, dropout=config.drop_out)
|
216 |
+
self.output_layer = nn.Linear(config.RNN_size, config.vocabulary_size)
|
217 |
+
self.dropout = nn.Dropout(config.drop_out)
|
218 |
+
|
219 |
+
# Feature projection
|
220 |
+
self.feature_projection = nn.Linear(128 * 32 * 32, 1024)
|
221 |
+
|
222 |
+
def forward(self, images, target_sequence=None, max_length=30):
|
223 |
+
batch_size = images.size(0)
|
224 |
+
|
225 |
+
# Extract image features
|
226 |
+
features = self.feature_extractor(images)
|
227 |
+
features = features.view(batch_size, -1)
|
228 |
+
features = self.feature_projection(features)
|
229 |
+
|
230 |
+
if target_sequence is not None:
|
231 |
+
# Training mode with teacher forcing
|
232 |
+
seq_length = target_sequence.size(1)
|
233 |
+
embedded = self.embedding(target_sequence)
|
234 |
+
|
235 |
+
# Repeat features for each time step
|
236 |
+
features_repeated = features.unsqueeze(1).repeat(1, seq_length, 1)
|
237 |
+
|
238 |
+
# Concatenate features with embeddings
|
239 |
+
gru_input = torch.cat([embedded, features_repeated], dim=2)
|
240 |
+
|
241 |
+
output, _ = self.gru(gru_input)
|
242 |
+
output = self.dropout(output)
|
243 |
+
output = self.output_layer(output)
|
244 |
+
|
245 |
+
return output
|
246 |
+
else:
|
247 |
+
# Inference mode
|
248 |
+
outputs = []
|
249 |
+
hidden = None
|
250 |
+
input_token = torch.zeros(batch_size, 1, dtype=torch.long, device=images.device)
|
251 |
+
|
252 |
+
for _ in range(max_length):
|
253 |
+
embedded = self.embedding(input_token)
|
254 |
+
features_step = features.unsqueeze(1)
|
255 |
+
gru_input = torch.cat([embedded, features_step], dim=2)
|
256 |
+
|
257 |
+
output, hidden = self.gru(gru_input, hidden)
|
258 |
+
output = self.output_layer(output)
|
259 |
+
outputs.append(output)
|
260 |
+
|
261 |
+
input_token = output.argmax(dim=-1)
|
262 |
+
|
263 |
+
return torch.cat(outputs, dim=1)
|
264 |
+
|
265 |
+
def generate_text(self, image_tensor, max_length=30, temperature=1.0):
|
266 |
+
"""Generate text using sequence-to-sequence model"""
|
267 |
+
self.eval()
|
268 |
+
with torch.no_grad():
|
269 |
+
if image_tensor.dim() == 3:
|
270 |
+
image_tensor = image_tensor.unsqueeze(0)
|
271 |
+
|
272 |
+
output = self.forward(image_tensor, max_length=max_length)
|
273 |
+
|
274 |
+
if temperature != 1.0:
|
275 |
+
output = output / temperature
|
276 |
+
|
277 |
+
predicted_ids = output.argmax(dim=-1).squeeze(0)
|
278 |
+
|
279 |
+
# Convert to text (placeholder implementation)
|
280 |
+
text = self._ids_to_text(predicted_ids)
|
281 |
+
return text
|
282 |
+
|
283 |
+
def _ids_to_text(self, ids):
|
284 |
+
"""Convert token IDs to text"""
|
285 |
+
# Placeholder implementation - you'll need to implement based on your vocabulary
|
286 |
+
return "generated_text"
|
gregg_recognition/models/image_to_text_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0ae75bbd1c6adc3d7cd508d4c53c09d2e7e8045f365ef5989dbca158baa437e4
|
3 |
+
size 2201277
|
gregg_recognition/models/seq2seq_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d6cbe5593d70d455d60f34f3d5150231a51b3a98ff139b5136caed1326def868
|
3 |
+
size 546413749
|
gregg_recognition/recognizer.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Main recognizer class for Gregg Shorthand Recognition
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
import os
|
10 |
+
from typing import Union, List, Optional
|
11 |
+
import torchvision.transforms as transforms
|
12 |
+
|
13 |
+
from .models import Seq2SeqModel, ImageToTextModel
|
14 |
+
from .config import Seq2SeqConfig, ImageToTextConfig
|
15 |
+
|
16 |
+
class GreggRecognition:
|
17 |
+
"""
|
18 |
+
class for recognizing Gregg shorthand from images
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
model_type: str = "image_to_text",
|
24 |
+
device: str = "auto",
|
25 |
+
model_path: Optional[str] = None,
|
26 |
+
config: Optional[Union[Seq2SeqConfig, ImageToTextConfig]] = None
|
27 |
+
):
|
28 |
+
"""
|
29 |
+
init GreggRecognition
|
30 |
+
|
31 |
+
Args:
|
32 |
+
model_type: "image_to_text" or "seq2seq"
|
33 |
+
device: "auto", "cpu", or "cuda"
|
34 |
+
model_path: Path to custom model file
|
35 |
+
config: Custom configuration object
|
36 |
+
"""
|
37 |
+
self.model_type = model_type
|
38 |
+
self.device = self._setup_device(device)
|
39 |
+
|
40 |
+
# handle config
|
41 |
+
if config is None:
|
42 |
+
if model_type == "image_to_text":
|
43 |
+
self.config = ImageToTextConfig()
|
44 |
+
elif model_type == "seq2seq":
|
45 |
+
self.config = Seq2SeqConfig()
|
46 |
+
else:
|
47 |
+
raise ValueError(f"Unknown model type: {model_type}")
|
48 |
+
else:
|
49 |
+
self.config = config
|
50 |
+
|
51 |
+
# init image preprocessing
|
52 |
+
self._setup_preprocessing()
|
53 |
+
|
54 |
+
self.model = self._load_model(model_path)
|
55 |
+
|
56 |
+
def _setup_device(self, device: str) -> torch.device:
|
57 |
+
"""Setup the computation device"""
|
58 |
+
if device == "auto":
|
59 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
60 |
+
else:
|
61 |
+
return torch.device(device)
|
62 |
+
|
63 |
+
def _setup_preprocessing(self):
|
64 |
+
"""Setup image preprocessing pipeline"""
|
65 |
+
if self.model_type == "image_to_text":
|
66 |
+
self.transform = transforms.Compose([
|
67 |
+
transforms.Grayscale(num_output_channels=1),
|
68 |
+
transforms.Resize((self.config.image_height, self.config.image_width)),
|
69 |
+
transforms.ToTensor(),
|
70 |
+
transforms.Normalize(mean=[0.5], std=[0.5]) # Normalize to [-1, 1]
|
71 |
+
])
|
72 |
+
else: # seq2seq
|
73 |
+
self.transform = transforms.Compose([
|
74 |
+
transforms.Grayscale(num_output_channels=1),
|
75 |
+
transforms.Resize((256, 256)), # Default size for seq2seq
|
76 |
+
transforms.ToTensor(),
|
77 |
+
transforms.Normalize(mean=[0.5], std=[0.5])
|
78 |
+
])
|
79 |
+
|
80 |
+
def _load_model(self, model_path: Optional[str]) -> torch.nn.Module:
|
81 |
+
"""Load the model"""
|
82 |
+
if self.model_type == "image_to_text":
|
83 |
+
model = ImageToTextModel(self.config)
|
84 |
+
elif self.model_type == "seq2seq":
|
85 |
+
model = Seq2SeqModel(256, 256, self.config)
|
86 |
+
else:
|
87 |
+
raise ValueError(f"Unknown model type: {self.model_type}")
|
88 |
+
|
89 |
+
# decide model path
|
90 |
+
if model_path is None:
|
91 |
+
package_dir = os.path.dirname(os.path.abspath(__file__))
|
92 |
+
if self.model_type == "image_to_text":
|
93 |
+
model_path = os.path.join(package_dir, "models", "image_to_text_model.pth")
|
94 |
+
elif self.model_type == "seq2seq":
|
95 |
+
model_path = os.path.join(package_dir, "models", "seq2seq_model.pth")
|
96 |
+
|
97 |
+
# load weights
|
98 |
+
if model_path and os.path.exists(model_path):
|
99 |
+
try:
|
100 |
+
if hasattr(model, 'load_pretrained'):
|
101 |
+
success = model.load_pretrained(model_path)
|
102 |
+
if success:
|
103 |
+
print(f"loaded model")
|
104 |
+
else:
|
105 |
+
print(f"failed to load model from {model_path}")
|
106 |
+
else:
|
107 |
+
checkpoint = torch.load(model_path, map_location=self.device)
|
108 |
+
if 'model_state_dict' in checkpoint:
|
109 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
110 |
+
else:
|
111 |
+
model.load_state_dict(checkpoint)
|
112 |
+
print(f"loaded model from {model_path}")
|
113 |
+
except Exception as e:
|
114 |
+
print(f"error loading model from {model_path}: {e}")
|
115 |
+
else:
|
116 |
+
if model_path:
|
117 |
+
print(f"model file not found: {model_path}")
|
118 |
+
|
119 |
+
model.to(self.device)
|
120 |
+
model.eval()
|
121 |
+
return model
|
122 |
+
|
123 |
+
def _preprocess_image(self, image_path: str) -> torch.Tensor:
|
124 |
+
"""Preprocess a single image"""
|
125 |
+
try:
|
126 |
+
# load image
|
127 |
+
image = Image.open(image_path)
|
128 |
+
|
129 |
+
# apply transforms
|
130 |
+
image_tensor = self.transform(image)
|
131 |
+
|
132 |
+
# add batch dimension
|
133 |
+
image_tensor = image_tensor.unsqueeze(0) # (1, C, H, W)
|
134 |
+
|
135 |
+
return image_tensor.to(self.device)
|
136 |
+
|
137 |
+
except Exception as e:
|
138 |
+
raise ValueError(f"Error processing image {image_path}: {str(e)}")
|
139 |
+
|
140 |
+
def recognize(self, image_path: str, **kwargs) -> str:
|
141 |
+
"""
|
142 |
+
Recognize shorthand from an image
|
143 |
+
|
144 |
+
Args:
|
145 |
+
image_path: Path to the image file
|
146 |
+
**kwargs: Additional options for generation
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
Recognized text string
|
150 |
+
"""
|
151 |
+
# Preprocess image
|
152 |
+
image_tensor = self._preprocess_image(image_path)
|
153 |
+
|
154 |
+
with torch.no_grad():
|
155 |
+
if self.model_type == "image_to_text":
|
156 |
+
# image-to-text
|
157 |
+
beam_size = kwargs.get('beam_size', 1)
|
158 |
+
result = self.model.generate_text(image_tensor, beam_size=beam_size)
|
159 |
+
return result if result else ""
|
160 |
+
|
161 |
+
elif self.model_type == "seq2seq":
|
162 |
+
# Sequence-to-sequence
|
163 |
+
return self._generate_seq2seq(image_tensor, **kwargs)
|
164 |
+
|
165 |
+
def _generate_seq2seq(self, image_tensor: torch.Tensor, **kwargs) -> str:
|
166 |
+
"""Generate text using seq2seq model"""
|
167 |
+
max_length = kwargs.get('max_length', 50)
|
168 |
+
temperature = kwargs.get('temperature', 1.0)
|
169 |
+
|
170 |
+
# Create character mappings
|
171 |
+
char_to_idx = {chr(i + ord('a')): i for i in range(26)}
|
172 |
+
char_to_idx[' '] = 26
|
173 |
+
char_to_idx['<END>'] = 27
|
174 |
+
idx_to_char = {v: k for k, v in char_to_idx.items()}
|
175 |
+
|
176 |
+
# Start with empty context
|
177 |
+
context = torch.zeros(1, 1, dtype=torch.long, device=self.device)
|
178 |
+
generated_text = ""
|
179 |
+
|
180 |
+
for _ in range(max_length):
|
181 |
+
# Get predictions
|
182 |
+
predictions = self.model(image_tensor, context)
|
183 |
+
|
184 |
+
# Get last prediction
|
185 |
+
last_pred = predictions[:, -1, :] # (1, vocab_size)
|
186 |
+
|
187 |
+
# Apply temperature
|
188 |
+
if temperature != 1.0:
|
189 |
+
last_pred = last_pred / temperature
|
190 |
+
|
191 |
+
# Sample next character
|
192 |
+
probs = F.softmax(last_pred, dim=-1)
|
193 |
+
next_char_idx = torch.multinomial(probs, 1).item()
|
194 |
+
|
195 |
+
# Convert to character
|
196 |
+
if next_char_idx in idx_to_char:
|
197 |
+
char = idx_to_char[next_char_idx]
|
198 |
+
if char == '<END>':
|
199 |
+
break
|
200 |
+
generated_text += char
|
201 |
+
|
202 |
+
# Update context
|
203 |
+
next_char_tensor = torch.tensor([[next_char_idx]], device=self.device)
|
204 |
+
context = torch.cat([context, next_char_tensor], dim=1)
|
205 |
+
|
206 |
+
return generated_text
|
207 |
+
|
208 |
+
def batch_recognize(self, image_paths: List[str], batch_size: int = 8, **kwargs) -> List[str]:
|
209 |
+
"""
|
210 |
+
Recognize shorthand from several images
|
211 |
+
|
212 |
+
Args:
|
213 |
+
image_paths: List of image file paths
|
214 |
+
batch_size: Batch size for processing
|
215 |
+
**kwargs: Additional options for generation
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
List of recognized text strings
|
219 |
+
"""
|
220 |
+
results = []
|
221 |
+
|
222 |
+
for i in range(0, len(image_paths), batch_size):
|
223 |
+
batch_paths = image_paths[i:i + batch_size]
|
224 |
+
batch_results = []
|
225 |
+
|
226 |
+
for path in batch_paths:
|
227 |
+
try:
|
228 |
+
result = self.recognize(path, **kwargs)
|
229 |
+
batch_results.append(result)
|
230 |
+
except Exception as e:
|
231 |
+
print(f"Error processing {path}: {str(e)}")
|
232 |
+
batch_results.append("")
|
233 |
+
|
234 |
+
results.extend(batch_results)
|
235 |
+
|
236 |
+
return results
|
237 |
+
|
238 |
+
def get_model_info(self) -> dict:
|
239 |
+
"""Get information about the loaded model"""
|
240 |
+
num_params = sum(p.numel() for p in self.model.parameters())
|
241 |
+
return {
|
242 |
+
"model_type": self.model_type,
|
243 |
+
"device": str(self.device),
|
244 |
+
"num_parameters": num_params,
|
245 |
+
"config": self.config.__dict__ if hasattr(self.config, '__dict__') else str(self.config)
|
246 |
+
}
|
requirements.txt
CHANGED
@@ -1,2 +1,5 @@
|
|
1 |
gradio==4.20.0
|
2 |
Pillow>=8.0.0
|
|
|
|
|
|
|
|
1 |
gradio==4.20.0
|
2 |
Pillow>=8.0.0
|
3 |
+
torch>=1.9.0
|
4 |
+
torchvision>=0.10.0
|
5 |
+
numpy>=1.21.0
|