# Provenance: uploaded by a0a7, commit e6769bb ("add real model")
"""
Configuration classes for Gregg Shorthand Recognition models
"""
import os
class Seq2SeqConfig:
    """Settings container for the sequence-to-sequence model.

    Every setting is exposed as a plain instance attribute, populated with
    its default value when the object is constructed.
    """

    def __init__(self):
        # Values that depend on the runtime environment are resolved first.
        module_dir = os.path.dirname(__file__)
        on_windows = os.name == 'nt'

        settings = {
            # Model Architecture
            'vocabulary_size': 28,
            'embedding_size': 256,
            'RNN_size': 512,
            'drop_out': 0.5,
            # Training Parameters
            'learning_rate': 0.001,
            'batch_size': 32,
            'weight_decay': 1e-5,
            'gradient_clip': 1.0,
            # Data: the data folder sits next to this module
            'data_folder': os.path.join(module_dir, 'data'),
            'val_proportion': 0.1,
            # Efficiency: no worker processes on Windows, four elsewhere
            'use_mixed_precision': True,
            'num_workers': 0 if on_windows else 4,
            'pin_memory': True,
            'compile_model': True,
            'prefetch_factor': 2,
            'persistent_workers': False,
            # Dataset
            'dataset_source': 'local',
            'hf_dataset_name': 'a0a7/Gregg-1916',
        }

        # Dicts preserve insertion order, so attributes are created in the
        # same order as they are listed above.
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)
class ImageToTextConfig:
    """Configuration for the direct image-to-text model.

    Besides holding the settings as instance attributes, the object owns
    the character vocabulary (``a``-``z``, space and an ``<END>`` token)
    and provides helpers to convert between text and index sequences.
    """

    def __init__(self):
        # Model Architecture
        self.vocabulary_size = 28  # a-z + space + end_token
        self.max_text_length = 20  # Maximum text output length

        # CNN Feature Extractor
        self.cnn_channels = [32, 64, 128, 256]  # Progressive channel sizes
        self.cnn_kernel_size = 3
        self.cnn_padding = 1
        self.use_batch_norm = True
        self.dropout_cnn = 0.2

        # Text Decoder
        self.decoder_hidden_size = 512
        self.decoder_num_layers = 2
        self.decoder_dropout = 0.3

        # Training Parameters
        self.learning_rate = 0.001
        self.batch_size = 32
        self.weight_decay = 1e-5
        self.gradient_clip = 1.0

        # Image Processing
        self.image_height = 256
        self.image_width = 256
        self.image_channels = 1  # Grayscale

        # Data: the data folder sits next to this module
        self.data_folder = os.path.join(os.path.dirname(__file__), 'data')
        self.val_proportion = 0.1

        # Efficiency: no worker processes on Windows, four elsewhere
        self.use_mixed_precision = True
        self.num_workers = 0 if os.name == 'nt' else 4
        self.pin_memory = True

        # Character mapping: 'a'..'z' -> 0..25, then space and the END token
        self.char_to_idx = {chr(i + ord('a')): i for i in range(26)}
        self.char_to_idx[' '] = 26  # Space
        self.char_to_idx['<END>'] = 27  # End token

        # Reverse mapping
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}

    def encode_text(self, text):
        """Convert text to a fixed-length sequence of indices.

        The text is lower-cased and characters missing from
        ``char_to_idx`` are dropped. The result always has exactly
        ``max_text_length`` entries: the encoded characters (truncated if
        necessary so that at least one END token fits) followed by END
        tokens as terminator and padding.

        Args:
            text: The string to encode.

        Returns:
            A list of ``max_text_length`` integer indices, or an empty list
            when ``max_text_length`` is not positive.
        """
        max_len = self.max_text_length
        if max_len <= 0:
            # No room for even the END token; a zero-length target is empty
            # rather than an IndexError on the final END assignment.
            return []

        end_idx = self.char_to_idx['<END>']
        indices = [
            self.char_to_idx[char]
            for char in text.lower()
            if char in self.char_to_idx
        ]

        # Reserve the last slot for END, then fill the remainder with END.
        kept = indices[:max_len - 1]
        return kept + [end_idx] * (max_len - len(kept))

    def decode_indices(self, indices):
        """Convert a sequence of indices back to text.

        Decoding stops at the first END token; indices that are not in
        ``idx_to_char`` are skipped.

        Args:
            indices: An iterable of indices.

        Returns:
            The decoded string (empty when the first index is END or the
            iterable is empty).
        """
        end_idx = self.char_to_idx['<END>']
        chars = []
        for idx in indices:
            if idx == end_idx:
                break
            char = self.idx_to_char.get(idx)
            if char is not None:
                chars.append(char)
        return "".join(chars)