""" | |
Main recognizer class for Gregg Shorthand Recognition | |
""" | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
import numpy as np | |
import os | |
from typing import Union, List, Optional | |
import torchvision.transforms as transforms | |
from .models import Seq2SeqModel, ImageToTextModel | |
from .config import Seq2SeqConfig, ImageToTextConfig | |


class GreggRecognition:
    """
    Class for recognizing Gregg shorthand from images.
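
    Example (hypothetical image path; assumes a usable checkpoint is either
    bundled with the package or supplied via ``model_path``):

        recognizer = GreggRecognition(model_type="image_to_text")
        text = recognizer.recognize("sample_outline.png")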
""" | |

    def __init__(
        self,
        model_type: str = "image_to_text",
        device: str = "auto",
        model_path: Optional[str] = None,
        config: Optional[Union[Seq2SeqConfig, ImageToTextConfig]] = None
    ):
        """
        Initialize a GreggRecognition instance.

        Args:
            model_type: "image_to_text" or "seq2seq"
            device: "auto", "cpu", or "cuda"
            model_path: Path to a custom model checkpoint file
            config: Custom configuration object
        """
        self.model_type = model_type
        self.device = self._setup_device(device)

        # Pick a default config matching the model type if none was given
        if config is None:
            if model_type == "image_to_text":
                self.config = ImageToTextConfig()
            elif model_type == "seq2seq":
                self.config = Seq2SeqConfig()
            else:
                raise ValueError(f"Unknown model type: {model_type}")
        else:
            self.config = config

        # Build the image preprocessing pipeline, then load the model
        self._setup_preprocessing()
        self.model = self._load_model(model_path)

    def _setup_device(self, device: str) -> torch.device:
        """Set up the computation device."""
        if device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            return torch.device(device)

    def _setup_preprocessing(self):
        """Set up the image preprocessing pipeline."""
        if self.model_type == "image_to_text":
            self.transform = transforms.Compose([
                transforms.Grayscale(num_output_channels=1),
                transforms.Resize((self.config.image_height, self.config.image_width)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
            ])
        else:  # seq2seq
            self.transform = transforms.Compose([
                transforms.Grayscale(num_output_channels=1),
                transforms.Resize((256, 256)),  # Default input size for seq2seq
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5])
            ])

    def _load_model(self, model_path: Optional[str]) -> torch.nn.Module:
        """Instantiate the model and load pretrained weights if available."""
        if self.model_type == "image_to_text":
            model = ImageToTextModel(self.config)
        elif self.model_type == "seq2seq":
            model = Seq2SeqModel(256, 256, self.config)
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # Fall back to the checkpoint bundled with the package
        if model_path is None:
            package_dir = os.path.dirname(os.path.abspath(__file__))
            if self.model_type == "image_to_text":
                model_path = os.path.join(package_dir, "models", "image_to_text_model.pth")
            elif self.model_type == "seq2seq":
                model_path = os.path.join(package_dir, "models", "seq2seq_model.pth")

        # Load weights, preferring the model's own loader if it has one
        if model_path and os.path.exists(model_path):
            try:
                if hasattr(model, 'load_pretrained'):
                    success = model.load_pretrained(model_path)
                    if success:
                        print(f"Loaded model from {model_path}")
                    else:
                        print(f"Failed to load model from {model_path}")
                else:
                    checkpoint = torch.load(model_path, map_location=self.device)
                    if 'model_state_dict' in checkpoint:
                        model.load_state_dict(checkpoint['model_state_dict'])
                    else:
                        model.load_state_dict(checkpoint)
                    print(f"Loaded model from {model_path}")
            except Exception as e:
                print(f"Error loading model from {model_path}: {e}")
        elif model_path:
            print(f"Model file not found: {model_path}")

        model.to(self.device)
        model.eval()
        return model

    def _preprocess_image(self, image_path: str) -> torch.Tensor:
        """Load and preprocess a single image."""
        try:
            image = Image.open(image_path)
            image_tensor = self.transform(image)
            image_tensor = image_tensor.unsqueeze(0)  # Add batch dimension: (1, C, H, W)
            return image_tensor.to(self.device)
        except Exception as e:
            raise ValueError(f"Error processing image {image_path}: {str(e)}")

    def recognize(self, image_path: str, **kwargs) -> str:
        """
        Recognize shorthand from an image.

        Args:
            image_path: Path to the image file
            **kwargs: Additional generation options (e.g. beam_size for
                image_to_text; max_length and temperature for seq2seq)

        Returns:
            Recognized text string
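
        Example (hypothetical path):

            text = recognizer.recognize("outline.png", beam_size=3)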
""" | |
# Preprocess image | |
image_tensor = self._preprocess_image(image_path) | |
with torch.no_grad(): | |
if self.model_type == "image_to_text": | |
# image-to-text | |
beam_size = kwargs.get('beam_size', 1) | |
result = self.model.generate_text(image_tensor, beam_size=beam_size) | |
return result if result else "" | |
elif self.model_type == "seq2seq": | |
# Sequence-to-sequence | |
return self._generate_seq2seq(image_tensor, **kwargs) | |

    def _generate_seq2seq(self, image_tensor: torch.Tensor, **kwargs) -> str:
        """Generate text autoregressively with the seq2seq model."""
        max_length = kwargs.get('max_length', 50)
        temperature = kwargs.get('temperature', 1.0)

        # Character vocabulary: 'a'-'z' map to 0-25, space to 26, <END> to 27
        char_to_idx = {chr(i + ord('a')): i for i in range(26)}
        char_to_idx[' '] = 26
        char_to_idx['<END>'] = 27
        idx_to_char = {v: k for k, v in char_to_idx.items()}

        # Start decoding from a single zero token as the initial context
        context = torch.zeros(1, 1, dtype=torch.long, device=self.device)
        generated_text = ""

        for _ in range(max_length):
            predictions = self.model(image_tensor, context)
            last_pred = predictions[:, -1, :]  # Logits for the next token: (1, vocab_size)

            # Temperature scaling: <1 sharpens, >1 flattens the distribution
            if temperature != 1.0:
                last_pred = last_pred / temperature

            # Sample the next character from the softmax distribution
            probs = F.softmax(last_pred, dim=-1)
            next_char_idx = torch.multinomial(probs, 1).item()

            if next_char_idx in idx_to_char:
                char = idx_to_char[next_char_idx]
                if char == '<END>':
                    break
                generated_text += char

            # Append the sampled token to the decoding context
            next_char_tensor = torch.tensor([[next_char_idx]], device=self.device)
            context = torch.cat([context, next_char_tensor], dim=1)

        return generated_text

    def batch_recognize(self, image_paths: List[str], batch_size: int = 8, **kwargs) -> List[str]:
        """
        Recognize shorthand from multiple images.

        Args:
            image_paths: List of image file paths
            batch_size: Chunk size for iterating over the list (images are
                still processed one at a time)
            **kwargs: Additional generation options

        Returns:
            List of recognized text strings
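
        Example (hypothetical paths):

            texts = recognizer.batch_recognize(["page1.png", "page2.png"])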
""" | |
results = [] | |
for i in range(0, len(image_paths), batch_size): | |
batch_paths = image_paths[i:i + batch_size] | |
batch_results = [] | |
for path in batch_paths: | |
try: | |
result = self.recognize(path, **kwargs) | |
batch_results.append(result) | |
except Exception as e: | |
print(f"Error processing {path}: {str(e)}") | |
batch_results.append("") | |
results.extend(batch_results) | |
return results | |

    def get_model_info(self) -> dict:
        """Get information about the loaded model."""
        num_params = sum(p.numel() for p in self.model.parameters())
        return {
            "model_type": self.model_type,
            "device": str(self.device),
            "num_parameters": num_params,
            "config": self.config.__dict__ if hasattr(self.config, '__dict__') else str(self.config)
        }
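

# A minimal usage sketch, not part of the class. The image paths are
# hypothetical, and a usable checkpoint is assumed to be available. Since
# this module uses relative imports, run it as a module from the package
# root rather than as a standalone script.
if __name__ == "__main__":
    recognizer = GreggRecognition(model_type="image_to_text", device="auto")
    print(recognizer.get_model_info())

    # Single image
    text = recognizer.recognize("sample_outline.png", beam_size=3)
    print(f"Recognized: {text}")

    # Multiple images, iterated in chunks of 4
    paths = ["page1.png", "page2.png"]
    for path, result in zip(paths, recognizer.batch_recognize(paths, batch_size=4)):
        print(f"{path}: {result}")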