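"""BPE tokenizer for Hindi text.

Loads a trained byte-pair-encoding merge table from a pickle file and
tokenizes text with a GPT-style split pattern extended so that Devanagari
base characters stay grouped with their combining marks.
"""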
import pickle
import regex as re
from typing import List, Tuple

class HindiTokenizer:
    def __init__(self, model_path: str = 'bpe_results.pkl'):
        # Load the BPE model
        with open(model_path, 'rb') as f:
            self.merges, self.ids, self.num_merges = pickle.load(f)
            
        # Rebuild the vocabulary: 256 base bytes plus one entry per learned
        # merge. Assumes self.merges preserves training order, which dicts
        # built in merge order guarantee on Python 3.7+.
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]
            
        # GPT-style split pattern with added Devanagari ranges: each base
        # letter is matched together with any combining marks (matras, Vedic
        # signs) that follow it
        self.pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{N}+| ?(?:[\u0904-\u0939\u093d-\u093d\u0950-\u0950\u0958-\u0961\u0970-\u097f\ua8f2-\ua8fe\U00011b00-\U00011b09\u1cd3-\u1cd3\u1ce9-\u1cec\u1cee-\u1cf3\u1cf5-\u1cf6\u1cfa-\u1cfa][\u0900-\u0903\u093a-\u093c\u093e-\u094f\u0951-\u0957\u0962-\u0963\ua8e0-\ua8f1\ua8ff-\ua8ff\u1cd0-\u1cd2\u1cd4-\u1ce8\u1ced-\u1ced\u1cf4-\u1cf4\u1cf7-\u1cf9]*)+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
    
    def tokenize(self, text: str) -> Tuple[List[int], List[str], List[str]]:
        """Tokenize text; returns (token ids, regex chunks, decoded pieces)."""
        # Split the text into chunks with the compiled pattern
        tokens = self.pattern.findall(text)
        
        # Convert tokens to byte sequences and maintain grouping
        byte_tokens = [token.encode('utf-8') for token in tokens]
        token_list = [list(token) for token in byte_tokens]
        
        # Process each token
        final_tokens = []
        for token in token_list:
            current_token = list(token)
            while len(current_token) >= 2:
                stats = self._get_stats([current_token])
                if not stats:
                    break
                # Choose the pair with the lowest merge rank, i.e. the merge
                # learned earliest during training
                pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
                if pair not in self.merges:
                    break  # nothing left to merge in this chunk
                idx = self.merges[pair]
                current_token = self._merge([current_token], pair, idx)[0]
            final_tokens.extend(current_token)
        
        # Decode the tokens
        decoded_tokens = [self.vocab[idx].decode("utf-8", errors="replace") for idx in final_tokens]
        
        return final_tokens, tokens, decoded_tokens
    
    def _get_stats(self, token_list):
        """Count frequency of pairs across all tokens"""
        counts = {}
        for token in token_list:
            if len(token) < 2:
                continue
            for pair in zip(token, token[1:]):
                counts[pair] = counts.get(pair, 0) + 1
        return counts
    
    def _merge(self, token_list, pair, idx):
        """Merge all occurrences of pair within each token"""
        newids = []
        for token in token_list:
            if len(token) < 2:
                newids.append(token)
                continue
                
            new_token = []
            i = 0
            while i < len(token):
                if i < len(token) - 1 and (token[i], token[i+1]) == pair:
                    new_token.append(idx)
                    i += 2
                else:
                    new_token.append(token[i])
                    i += 1
            newids.append(new_token)
        return newids
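
# Quick usage sketch: assumes a trained model pickle ('bpe_results.pkl',
# the default path above) exists on disk; the sample sentence is only
# illustrative.
if __name__ == "__main__":
    tokenizer = HindiTokenizer()
    ids, chunks, pieces = tokenizer.tokenize("नमस्ते दुनिया")
    print("token ids   :", ids)
    print("regex chunks:", chunks)
    print("pieces      :", pieces)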