"""
Configuration classes for Gregg Shorthand Recognition models
"""

import os

class Seq2SeqConfig:
    """Configuration for the sequence-to-sequence model"""
    
    def __init__(self):
        # Model Architecture
        self.vocabulary_size = 28  # a-z + space + end token (see ImageToTextConfig.char_to_idx)
        self.embedding_size = 256
        self.RNN_size = 512
        self.drop_out = 0.5
        
        # Training Parameters
        self.learning_rate = 0.001
        self.batch_size = 32
        self.weight_decay = 1e-5
        self.gradient_clip = 1.0
        
        # Data
        self.data_folder = os.path.join(os.path.dirname(__file__), 'data')
        self.val_proportion = 0.1
        
        # Efficiency
        self.use_mixed_precision = True
        self.num_workers = 0 if os.name == 'nt' else 4  # 0 avoids multiprocessing worker issues on Windows
        self.pin_memory = True
        self.compile_model = True
        self.prefetch_factor = 2
        self.persistent_workers = False
        
        # Dataset
        self.dataset_source = 'local'
        self.hf_dataset_name = 'a0a7/Gregg-1916'

class ImageToTextConfig:
    """Configuration for the direct image-to-text model"""
    
    def __init__(self):
        # Model Architecture
        self.vocabulary_size = 28  # a-z + space + end_token
        self.max_text_length = 20  # Maximum text output length
        
        # CNN Feature Extractor
        self.cnn_channels = [32, 64, 128, 256]  # Progressive channel sizes
        self.cnn_kernel_size = 3
        self.cnn_padding = 1
        self.use_batch_norm = True
        self.dropout_cnn = 0.2
        
        # Text Decoder
        self.decoder_hidden_size = 512
        self.decoder_num_layers = 2
        self.decoder_dropout = 0.3
        
        # Training Parameters
        self.learning_rate = 0.001
        self.batch_size = 32
        self.weight_decay = 1e-5
        self.gradient_clip = 1.0
        
        # Image Processing
        self.image_height = 256
        self.image_width = 256
        self.image_channels = 1  # Grayscale
        
        # Data
        self.data_folder = os.path.join(os.path.dirname(__file__), 'data')
        self.val_proportion = 0.1
        
        # Efficiency
        self.use_mixed_precision = True
        self.num_workers = 0 if os.name == 'nt' else 4
        self.pin_memory = True
        
        # Character mapping
        self.char_to_idx = {chr(i + ord('a')): i for i in range(26)}
        self.char_to_idx[' '] = 26  # Space
        self.char_to_idx['<END>'] = 27  # End token
        
        # Reverse mapping
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
    
    def encode_text(self, text):
        """Convert text to sequence of indices"""
        indices = []
        for char in text.lower():
            # Characters outside the vocabulary (digits, punctuation, etc.) are skipped
            if char in self.char_to_idx:
                indices.append(self.char_to_idx[char])
        
        # Add END token
        indices.append(self.char_to_idx['<END>'])
        
        # Pad or truncate to max_text_length (the END token doubles as padding)
        if len(indices) < self.max_text_length:
            indices.extend([self.char_to_idx['<END>']] * (self.max_text_length - len(indices)))
        else:
            indices = indices[:self.max_text_length]
            indices[-1] = self.char_to_idx['<END>']  # Ensure last token is END
        
        return indices
    
    def decode_indices(self, indices):
        """Convert sequence of indices back to text"""
        text = ""
        for idx in indices:
            if idx == self.char_to_idx['<END>']:
                break
            if idx in self.idx_to_char:
                text += self.idx_to_char[idx]
        return text
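

if __name__ == "__main__":
    # Minimal usage sketch for running this module on its own. It only exercises
    # the encoding helpers defined above; the training code that consumes these
    # configs lives elsewhere and is not assumed here.
    cfg = ImageToTextConfig()

    sample = "gregg shorthand"
    encoded = cfg.encode_text(sample)
    print(f"Encoded ({len(encoded)} indices): {encoded}")
    print(f"Decoded back: '{cfg.decode_indices(encoded)}'")  # -> 'gregg shorthand'

    # Texts longer than max_text_length are truncated, with END forced as the
    # final index so decode_indices always terminates.
    long_encoded = cfg.encode_text("a much longer phrase than twenty characters")
    assert len(long_encoded) == cfg.max_text_length
    assert long_encoded[-1] == cfg.char_to_idx['<END>']
    print(f"Truncated decode: '{cfg.decode_indices(long_encoded)}'")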