"""
Configuration classes for Gregg Shorthand Recognition models
"""
import os
class Seq2SeqConfig:
    """Hyper-parameters and runtime settings for the sequence-to-sequence model.

    Groups architecture, optimisation, data-location, and loader-efficiency
    knobs in one object; construct once and pass to training/eval code.
    """

    def __init__(self):
        # --- Architecture ---
        self.vocabulary_size = 28
        self.embedding_size = 256
        self.RNN_size = 512
        self.drop_out = 0.5

        # --- Optimisation ---
        self.learning_rate = 0.001
        self.batch_size = 32
        self.weight_decay = 1e-5
        self.gradient_clip = 1.0

        # --- Data location / validation split ---
        base_dir = os.path.dirname(__file__)
        self.data_folder = os.path.join(base_dir, 'data')
        self.val_proportion = 0.1

        # --- Loader / runtime efficiency ---
        self.use_mixed_precision = True
        # Multiprocess data loading is disabled on Windows ('nt').
        self.num_workers = 4 if os.name != 'nt' else 0
        self.pin_memory = True
        self.compile_model = True
        self.prefetch_factor = 2
        self.persistent_workers = False

        # --- Dataset source ---
        self.dataset_source = 'local'
        self.hf_dataset_name = 'a0a7/Gregg-1916'
class ImageToTextConfig:
    """Settings for the direct image-to-text model, plus text codec helpers.

    Besides hyper-parameters, this config owns the character vocabulary
    (a-z, space, <END>) and the encode/decode routines that map between
    strings and fixed-length index sequences.
    """

    def __init__(self):
        # Vocabulary: 26 letters + space + end-of-sequence marker.
        self.vocabulary_size = 28
        self.max_text_length = 20  # fixed length of every encoded sequence

        # CNN feature extractor.
        self.cnn_channels = [32, 64, 128, 256]  # channels per stage
        self.cnn_kernel_size = 3
        self.cnn_padding = 1
        self.use_batch_norm = True
        self.dropout_cnn = 0.2

        # Recurrent text decoder.
        self.decoder_hidden_size = 512
        self.decoder_num_layers = 2
        self.decoder_dropout = 0.3

        # Optimisation.
        self.learning_rate = 0.001
        self.batch_size = 32
        self.weight_decay = 1e-5
        self.gradient_clip = 1.0

        # Input image geometry (single-channel grayscale).
        self.image_height = 256
        self.image_width = 256
        self.image_channels = 1

        # Data location / validation split.
        self.data_folder = os.path.join(os.path.dirname(__file__), 'data')
        self.val_proportion = 0.1

        # Loader efficiency; multiprocess loading disabled on Windows ('nt').
        self.use_mixed_precision = True
        self.num_workers = 0 if os.name == 'nt' else 4
        self.pin_memory = True

        # Character <-> index tables: 'a'..'z' -> 0..25, then space and <END>.
        self.char_to_idx = {
            chr(code): code - ord('a') for code in range(ord('a'), ord('z') + 1)
        }
        self.char_to_idx[' '] = 26
        self.char_to_idx['<END>'] = 27
        self.idx_to_char = {idx: ch for ch, idx in self.char_to_idx.items()}

    def encode_text(self, text):
        """Lower-case *text* and return exactly ``max_text_length`` indices.

        Characters outside the vocabulary are dropped. The sequence is
        END-terminated and END-padded; over-long input is truncated with
        the final slot forced to <END>.
        """
        end = self.char_to_idx['<END>']
        seq = [self.char_to_idx[ch] for ch in text.lower() if ch in self.char_to_idx]
        seq.append(end)
        if len(seq) >= self.max_text_length:
            seq = seq[:self.max_text_length]
            seq[-1] = end  # guarantee a terminator even after truncation
        else:
            seq += [end] * (self.max_text_length - len(seq))
        return seq

    def decode_indices(self, indices):
        """Map *indices* back to a string, stopping at the first <END>."""
        end = self.char_to_idx['<END>']
        chars = []
        for idx in indices:
            if idx == end:
                break
            if idx in self.idx_to_char:
                chars.append(self.idx_to_char[idx])
        return "".join(chars)