""" Main recognizer class for Gregg Shorthand Recognition """ import torch import torch.nn.functional as F from PIL import Image import numpy as np import os from typing import Union, List, Optional import torchvision.transforms as transforms from .models import Seq2SeqModel, ImageToTextModel from .config import Seq2SeqConfig, ImageToTextConfig class GreggRecognition: """ class for recognizing Gregg shorthand from images """ def __init__( self, model_type: str = "image_to_text", device: str = "auto", model_path: Optional[str] = None, config: Optional[Union[Seq2SeqConfig, ImageToTextConfig]] = None ): """ init GreggRecognition Args: model_type: "image_to_text" or "seq2seq" device: "auto", "cpu", or "cuda" model_path: Path to custom model file config: Custom configuration object """ self.model_type = model_type self.device = self._setup_device(device) # handle config if config is None: if model_type == "image_to_text": self.config = ImageToTextConfig() elif model_type == "seq2seq": self.config = Seq2SeqConfig() else: raise ValueError(f"Unknown model type: {model_type}") else: self.config = config # init image preprocessing self._setup_preprocessing() self.model = self._load_model(model_path) def _setup_device(self, device: str) -> torch.device: """Setup the computation device""" if device == "auto": return torch.device("cuda" if torch.cuda.is_available() else "cpu") else: return torch.device(device) def _setup_preprocessing(self): """Setup image preprocessing pipeline""" if self.model_type == "image_to_text": self.transform = transforms.Compose([ transforms.Grayscale(num_output_channels=1), transforms.Resize((self.config.image_height, self.config.image_width)), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) # Normalize to [-1, 1] ]) else: # seq2seq self.transform = transforms.Compose([ transforms.Grayscale(num_output_channels=1), transforms.Resize((256, 256)), # Default size for seq2seq transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ]) def _load_model(self, model_path: Optional[str]) -> torch.nn.Module: """Load the model""" if self.model_type == "image_to_text": model = ImageToTextModel(self.config) elif self.model_type == "seq2seq": model = Seq2SeqModel(256, 256, self.config) else: raise ValueError(f"Unknown model type: {self.model_type}") # decide model path if model_path is None: package_dir = os.path.dirname(os.path.abspath(__file__)) if self.model_type == "image_to_text": model_path = os.path.join(package_dir, "models", "image_to_text_model.pth") elif self.model_type == "seq2seq": model_path = os.path.join(package_dir, "models", "seq2seq_model.pth") # load weights if model_path and os.path.exists(model_path): try: if hasattr(model, 'load_pretrained'): success = model.load_pretrained(model_path) if success: print(f"loaded model") else: print(f"failed to load model from {model_path}") else: checkpoint = torch.load(model_path, map_location=self.device) if 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(checkpoint) print(f"loaded model from {model_path}") except Exception as e: print(f"error loading model from {model_path}: {e}") else: if model_path: print(f"model file not found: {model_path}") model.to(self.device) model.eval() return model def _preprocess_image(self, image_path: str) -> torch.Tensor: """Preprocess a single image""" try: # load image image = Image.open(image_path) # apply transforms image_tensor = self.transform(image) # add batch dimension image_tensor = 
image_tensor.unsqueeze(0) # (1, C, H, W) return image_tensor.to(self.device) except Exception as e: raise ValueError(f"Error processing image {image_path}: {str(e)}") def recognize(self, image_path: str, **kwargs) -> str: """ Recognize shorthand from an image Args: image_path: Path to the image file **kwargs: Additional options for generation Returns: Recognized text string """ # Preprocess image image_tensor = self._preprocess_image(image_path) with torch.no_grad(): if self.model_type == "image_to_text": # image-to-text beam_size = kwargs.get('beam_size', 1) result = self.model.generate_text(image_tensor, beam_size=beam_size) return result if result else "" elif self.model_type == "seq2seq": # Sequence-to-sequence return self._generate_seq2seq(image_tensor, **kwargs) def _generate_seq2seq(self, image_tensor: torch.Tensor, **kwargs) -> str: """Generate text using seq2seq model""" max_length = kwargs.get('max_length', 50) temperature = kwargs.get('temperature', 1.0) # Create character mappings char_to_idx = {chr(i + ord('a')): i for i in range(26)} char_to_idx[' '] = 26 char_to_idx[''] = 27 idx_to_char = {v: k for k, v in char_to_idx.items()} # Start with empty context context = torch.zeros(1, 1, dtype=torch.long, device=self.device) generated_text = "" for _ in range(max_length): # Get predictions predictions = self.model(image_tensor, context) # Get last prediction last_pred = predictions[:, -1, :] # (1, vocab_size) # Apply temperature if temperature != 1.0: last_pred = last_pred / temperature # Sample next character probs = F.softmax(last_pred, dim=-1) next_char_idx = torch.multinomial(probs, 1).item() # Convert to character if next_char_idx in idx_to_char: char = idx_to_char[next_char_idx] if char == '': break generated_text += char # Update context next_char_tensor = torch.tensor([[next_char_idx]], device=self.device) context = torch.cat([context, next_char_tensor], dim=1) return generated_text def batch_recognize(self, image_paths: List[str], batch_size: int = 8, **kwargs) -> List[str]: """ Recognize shorthand from several images Args: image_paths: List of image file paths batch_size: Batch size for processing **kwargs: Additional options for generation Returns: List of recognized text strings """ results = [] for i in range(0, len(image_paths), batch_size): batch_paths = image_paths[i:i + batch_size] batch_results = [] for path in batch_paths: try: result = self.recognize(path, **kwargs) batch_results.append(result) except Exception as e: print(f"Error processing {path}: {str(e)}") batch_results.append("") results.extend(batch_results) return results def get_model_info(self) -> dict: """Get information about the loaded model""" num_params = sum(p.numel() for p in self.model.parameters()) return { "model_type": self.model_type, "device": str(self.device), "num_parameters": num_params, "config": self.config.__dict__ if hasattr(self.config, '__dict__') else str(self.config) }
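

# Minimal usage sketch, kept behind the __main__ guard so it never runs on
# import. The image file names below are placeholders, not files shipped with
# the package; pretrained weights are assumed to either ship under models/ or
# be supplied explicitly via model_path.
if __name__ == "__main__":
    recognizer = GreggRecognition(model_type="image_to_text", device="auto")
    print(recognizer.get_model_info())

    # Single image; beam_size is forwarded to the image-to-text decoder
    text = recognizer.recognize("sample_outline.png", beam_size=3)
    print(f"Recognized: {text}")

    # Several images, processed sequentially in chunks of batch_size
    results = recognizer.batch_recognize(
        ["outline_1.png", "outline_2.png"], batch_size=2
    )
    print(results)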