"""
Configuration classes for Gregg Shorthand Recognition models
"""
import os


class Seq2SeqConfig:
    """Configuration for the sequence-to-sequence model"""

    def __init__(self):
        # Model Architecture
        self.vocabulary_size = 28
        self.embedding_size = 256
        self.RNN_size = 512
        self.drop_out = 0.5

        # Training Parameters
        self.learning_rate = 0.001
        self.batch_size = 32
        self.weight_decay = 1e-5
        self.gradient_clip = 1.0

        # Data
        self.data_folder = os.path.join(os.path.dirname(__file__), 'data')
        self.val_proportion = 0.1

        # Efficiency
        self.use_mixed_precision = True
        self.num_workers = 0 if os.name == 'nt' else 4
        self.pin_memory = True
        self.compile_model = True
        self.prefetch_factor = 2
        self.persistent_workers = False

        # Dataset
        self.dataset_source = 'local'
        self.hf_dataset_name = 'a0a7/Gregg-1916'
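

# Illustrative sketch, not part of the original module: one way the efficiency
# fields above could be mapped onto torch.utils.data.DataLoader keyword
# arguments. The `dataset` argument is assumed to be any PyTorch Dataset built
# from cfg.data_folder; the helper name is hypothetical.
def build_seq2seq_loader(dataset, cfg):
    """Hypothetical helper: build a DataLoader from Seq2SeqConfig efficiency settings."""
    from torch.utils.data import DataLoader

    kwargs = dict(
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
    )
    if cfg.num_workers > 0:
        # prefetch_factor / persistent_workers only apply when worker processes exist
        kwargs.update(
            prefetch_factor=cfg.prefetch_factor,
            persistent_workers=cfg.persistent_workers,
        )
    return DataLoader(dataset, **kwargs)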


class ImageToTextConfig:
    """Configuration for the direct image-to-text model"""

    def __init__(self):
        # Model Architecture
        self.vocabulary_size = 28  # a-z + space + end_token
        self.max_text_length = 20  # Maximum text output length

        # CNN Feature Extractor
        self.cnn_channels = [32, 64, 128, 256]  # Progressive channel sizes
        self.cnn_kernel_size = 3
        self.cnn_padding = 1
        self.use_batch_norm = True
        self.dropout_cnn = 0.2

        # Text Decoder
        self.decoder_hidden_size = 512
        self.decoder_num_layers = 2
        self.decoder_dropout = 0.3

        # Training Parameters
        self.learning_rate = 0.001
        self.batch_size = 32
        self.weight_decay = 1e-5
        self.gradient_clip = 1.0

        # Image Processing
        self.image_height = 256
        self.image_width = 256
        self.image_channels = 1  # Grayscale

        # Data
        self.data_folder = os.path.join(os.path.dirname(__file__), 'data')
        self.val_proportion = 0.1

        # Efficiency
        self.use_mixed_precision = True
        self.num_workers = 0 if os.name == 'nt' else 4
        self.pin_memory = True

        # Character mapping
        self.char_to_idx = {chr(i + ord('a')): i for i in range(26)}
        self.char_to_idx[' '] = 26  # Space
        self.char_to_idx['<END>'] = 27  # End token
        # Reverse mapping
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}

    def encode_text(self, text):
        """Convert text to a sequence of indices"""
        indices = []
        for char in text.lower():
            if char in self.char_to_idx:
                indices.append(self.char_to_idx[char])
        # Add END token
        indices.append(self.char_to_idx['<END>'])
        # Pad or truncate to max_text_length
        if len(indices) < self.max_text_length:
            indices.extend([self.char_to_idx['<END>']] * (self.max_text_length - len(indices)))
        else:
            indices = indices[:self.max_text_length]
            indices[-1] = self.char_to_idx['<END>']  # Ensure last token is END
        return indices

    def decode_indices(self, indices):
        """Convert a sequence of indices back to text"""
        text = ""
        for idx in indices:
            if idx == self.char_to_idx['<END>']:
                break
            if idx in self.idx_to_char:
                text += self.idx_to_char[idx]
        return text
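

# Illustrative usage, not part of the original module: round-trip a word through
# ImageToTextConfig.encode_text / decode_indices. The word "shorthand" is an
# arbitrary example.
if __name__ == "__main__":
    cfg = ImageToTextConfig()
    encoded = cfg.encode_text("shorthand")
    print(encoded)                      # 20 indices, padded with the <END> index (27)
    print(cfg.decode_indices(encoded))  # -> "shorthand"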