# Page residue from the hosting site (apparently commit metadata, not module code):
#   a0a7's picture
#   add real model
#   e6769bb
"""
Main recognizer class for Gregg Shorthand Recognition
"""
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import os
from typing import Union, List, Optional
import torchvision.transforms as transforms
from .models import Seq2SeqModel, ImageToTextModel
from .config import Seq2SeqConfig, ImageToTextConfig
class GreggRecognition:
    """
    Recognize Gregg shorthand from images.

    Wraps one of two models (``ImageToTextModel`` or ``Seq2SeqModel``) together
    with the image preprocessing pipeline that feeds it, and exposes single-image
    and multi-image recognition entry points.
    """

    # Accepted values for ``model_type``.
    _MODEL_TYPES = ("image_to_text", "seq2seq")

    # Size passed to ``transforms.Resize`` for the seq2seq pipeline.
    _SEQ2SEQ_IMAGE_SIZE = (256, 256)

    # Weight files looked up under ``<package>/models`` when no path is given.
    _DEFAULT_WEIGHT_FILES = {
        "image_to_text": "image_to_text_model.pth",
        "seq2seq": "seq2seq_model.pth",
    }

    # Seq2seq output vocabulary: 0-25 -> 'a'-'z', 26 -> ' ', 27 -> '<END>'.
    _END_TOKEN = "<END>"
    _IDX_TO_CHAR = dict(
        enumerate([chr(ord("a") + i) for i in range(26)] + [" ", _END_TOKEN])
    )

    def __init__(
        self,
        model_type: str = "image_to_text",
        device: str = "auto",
        model_path: Optional[str] = None,
        config: "Optional[Union[Seq2SeqConfig, ImageToTextConfig]]" = None,
    ):
        """
        Initialize the recognizer: pick a device, build preprocessing, load the model.

        Args:
            model_type: "image_to_text" or "seq2seq"
            device: "auto" (CUDA when available, else CPU) or any string
                accepted by ``torch.device`` such as "cpu" or "cuda"
            model_path: Path to a custom model file; when None a default file
                under the package's ``models`` directory is used
            config: Custom configuration object; when None the default config
                for ``model_type`` is created

        Raises:
            ValueError: If ``model_type`` is not one of the supported types.
        """
        # Validate up front so an unknown type fails before any setup work,
        # including when a custom config is supplied.
        if model_type not in self._MODEL_TYPES:
            raise ValueError(f"Unknown model type: {model_type}")

        self.model_type = model_type
        self.device = self._setup_device(device)

        if config is None:
            config = (
                ImageToTextConfig()
                if model_type == "image_to_text"
                else Seq2SeqConfig()
            )
        self.config = config

        # Preprocessing only needs the config; the model is loaded last.
        self._setup_preprocessing()
        self.model = self._load_model(model_path)

    def _setup_device(self, device: str) -> torch.device:
        """
        Resolve the computation device.

        Args:
            device: "auto" or a device string understood by ``torch.device``

        Returns:
            CUDA when ``device`` is "auto" and CUDA is available, CPU when
            "auto" and it is not, otherwise ``torch.device(device)``.
        """
        if device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device)

    def _setup_preprocessing(self):
        """
        Build ``self.transform``: grayscale -> resize -> tensor -> normalize.

        The image-to-text pipeline resizes to the size named in the config;
        every other model type uses the fixed seq2seq size.
        """
        if self.model_type == "image_to_text":
            target_size = (self.config.image_height, self.config.image_width)
        else:  # seq2seq
            target_size = self._SEQ2SEQ_IMAGE_SIZE

        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),  # Normalize to [-1, 1]
        ])

    def _load_model(self, model_path: Optional[str]) -> torch.nn.Module:
        """
        Instantiate the model, load weights if a file is found, and set eval mode.

        Weight loading is best-effort: a missing file or a loading error is
        reported on stdout and the model is returned with whatever weights it
        was constructed with.

        Args:
            model_path: Path to a weight file, or None to use the packaged default

        Returns:
            The model, moved to ``self.device`` and switched to eval mode.

        Raises:
            ValueError: If ``self.model_type`` is not a supported type.
        """
        if self.model_type == "image_to_text":
            model = ImageToTextModel(self.config)
        elif self.model_type == "seq2seq":
            # NOTE(review): 256, 256 presumably mirrors the resize size used in
            # preprocessing -- confirm against Seq2SeqModel's signature.
            model = Seq2SeqModel(256, 256, self.config)
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        if model_path is None:
            package_dir = os.path.dirname(os.path.abspath(__file__))
            model_path = os.path.join(
                package_dir, "models", self._DEFAULT_WEIGHT_FILES[self.model_type]
            )

        if model_path and os.path.exists(model_path):
            self._load_weights(model, model_path)
        elif model_path:
            print(f"model file not found: {model_path}")

        model.to(self.device)
        model.eval()
        return model

    def _load_weights(self, model: torch.nn.Module, model_path: str) -> None:
        """Load weights from ``model_path`` into ``model``, reporting the outcome on stdout."""
        # Deliberately broad: a bad checkpoint must not prevent construction.
        try:
            if hasattr(model, 'load_pretrained'):
                # Prefer the model's own loader when it provides one.
                if model.load_pretrained(model_path):
                    print("loaded model")
                else:
                    print(f"failed to load model from {model_path}")
                return

            # SECURITY: torch.load unpickles the file, which can execute
            # arbitrary code -- only point this at trusted checkpoints.
            checkpoint = torch.load(model_path, map_location=self.device)
            # Accept either a training checkpoint or a bare state dict.
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint)
            print(f"loaded model from {model_path}")
        except Exception as e:
            print(f"error loading model from {model_path}: {e}")

    def _preprocess_image(self, image_path: str) -> torch.Tensor:
        """
        Load one image and turn it into a batched tensor on ``self.device``.

        Args:
            image_path: Path to the image file

        Returns:
            The transformed image with a leading batch dimension of 1.

        Raises:
            ValueError: If the image cannot be opened or transformed; the
                original exception is attached as the cause.
        """
        try:
            # The context manager closes the underlying file handle; the
            # transform runs inside it so pixel data is read before closing.
            with Image.open(image_path) as image:
                image_tensor = self.transform(image)
            # add batch dimension
            return image_tensor.unsqueeze(0).to(self.device)  # (1, C, H, W)
        except Exception as e:
            raise ValueError(f"Error processing image {image_path}: {str(e)}") from e

    def recognize(self, image_path: str, **kwargs) -> str:
        """
        Recognize shorthand from an image

        Args:
            image_path: Path to the image file
            **kwargs: Additional options for generation. ``beam_size``
                (default 1) is used by the image-to-text model; ``max_length``
                (default 50) and ``temperature`` (default 1.0) by seq2seq.

        Returns:
            Recognized text string ("" when the model produces nothing)

        Raises:
            ValueError: If the image cannot be processed or ``model_type`` is
                not a supported type.
        """
        # Fail loudly rather than fall through and return None.
        if self.model_type not in self._MODEL_TYPES:
            raise ValueError(f"Unknown model type: {self.model_type}")

        image_tensor = self._preprocess_image(image_path)

        with torch.no_grad():
            if self.model_type == "image_to_text":
                beam_size = kwargs.get('beam_size', 1)
                result = self.model.generate_text(image_tensor, beam_size=beam_size)
                return result if result else ""
            return self._generate_seq2seq(image_tensor, **kwargs)

    def _generate_seq2seq(self, image_tensor: torch.Tensor, **kwargs) -> str:
        """
        Generate text one character at a time with the seq2seq model.

        Args:
            image_tensor: Preprocessed image batch passed to the model
            **kwargs: ``max_length`` (default 50) caps the number of decoding
                steps; ``temperature`` (default 1.0) scales the logits before
                sampling, and 0 selects the most likely character at each step

        Returns:
            The generated characters joined into a string. Generation stops at
            the ``<END>`` token or after ``max_length`` steps.

        Raises:
            ValueError: If ``temperature`` is negative.
        """
        max_length = kwargs.get('max_length', 50)
        temperature = kwargs.get('temperature', 1.0)
        if temperature < 0:
            raise ValueError(f"temperature must be non-negative, got {temperature}")

        # Start with a single index-0 token as the context.
        context = torch.zeros(1, 1, dtype=torch.long, device=self.device)
        generated = []

        for _ in range(max_length):
            predictions = self.model(image_tensor, context)
            last_pred = predictions[:, -1, :]  # scores for the newest position

            if temperature == 0:
                # Dividing by zero would yield NaN probabilities; the limit of
                # temperature -> 0 is greedy decoding.
                next_char_idx = int(torch.argmax(last_pred, dim=-1).item())
            else:
                if temperature != 1.0:
                    last_pred = last_pred / temperature
                probs = F.softmax(last_pred, dim=-1)
                next_char_idx = torch.multinomial(probs, 1).item()

            char = self._IDX_TO_CHAR.get(next_char_idx)
            if char == self._END_TOKEN:
                break
            # Indices outside the vocabulary emit nothing but still extend
            # the context below.
            if char is not None:
                generated.append(char)

            next_char_tensor = torch.tensor([[next_char_idx]], device=self.device)
            context = torch.cat([context, next_char_tensor], dim=1)

        return "".join(generated)

    def batch_recognize(self, image_paths: List[str], batch_size: int = 8, **kwargs) -> List[str]:
        """
        Recognize shorthand from several images

        Images are recognized one at a time; ``batch_size`` only controls how
        the paths are grouped while iterating. A failure on one image is
        reported on stdout and yields "" for that image so the output stays
        aligned with the input.

        Args:
            image_paths: List of image file paths
            batch_size: Batch size for processing (must be at least 1)
            **kwargs: Additional options for generation

        Returns:
            List of recognized text strings, one per input path, in order

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        # A negative step would make the range empty and silently drop
        # every image; zero is rejected by range() itself.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        results = []
        for start in range(0, len(image_paths), batch_size):
            for path in image_paths[start:start + batch_size]:
                try:
                    results.append(self.recognize(path, **kwargs))
                except Exception as e:
                    print(f"Error processing {path}: {str(e)}")
                    results.append("")
        return results

    def get_model_info(self) -> dict:
        """
        Get information about the loaded model

        Returns:
            Dict with ``model_type``, ``device`` (as a string),
            ``num_parameters`` (total element count over all parameters) and
            ``config`` (a copy of the config's attributes, or its string form
            when it has no ``__dict__``).
        """
        num_params = sum(p.numel() for p in self.model.parameters())
        if hasattr(self.config, '__dict__'):
            # Shallow copy so callers cannot mutate the live config.
            config_info = dict(vars(self.config))
        else:
            config_info = str(self.config)
        return {
            "model_type": self.model_type,
            "device": str(self.device),
            "num_parameters": num_params,
            "config": config_info,
        }