from transformers import PretrainedConfig, PreTrainedTokenizerBase
from freegroup import tools

class GreedyConfig(PretrainedConfig):
    """PretrainedConfig that records free group structure at the token level:
    generator/inverse token id pairs (reciprocals) and the token ids of the
    normal closure generators (reducables)."""

    @classmethod
    def from_tokenizer(cls, freegroup_dimension, tokenizer: PreTrainedTokenizerBase, **kwargs):
        """Build a config from a tokenizer whose vocabulary contains the generators
        as the string tokens str(x) and str(-x) for x in 1..freegroup_dimension."""
        config = cls(
            vocab_size = len(tokenizer),
            eos_token_id = tokenizer.eos_token_id,
            pad_token_id = tokenizer.pad_token_id,
            **kwargs
        )
        config._from_tokenizer(freegroup_dimension, tokenizer)
        return config

    def _from_tokenizer(self, freegroup_dimension, tokenizer):
        freegroup_generators = list(range(1, freegroup_dimension + 1))

        # Token id pairs (generator, inverse): [id(str(x)), id(str(-x))] for each generator x.
        self.reciprocals = []
        for x in freegroup_generators:
            a, b = tokenizer.convert_tokens_to_ids([str(x), str(-x)])
            self.reciprocals.append([a, b])

        # Normal closure generators: one singleton closure per generator, plus the
        # closure of the full generating set. For each closure store the token ids
        # of its generators and of their reciprocals.
        closure_generators = [[x] for x in freegroup_generators] + [freegroup_generators[:]]
        self.reducables = [[] for _ in range(freegroup_dimension + 1)]
        for reducable, closure_generator in zip(self.reducables, closure_generators):
            reducable.append(tokenizer.convert_tokens_to_ids(list(map(str, closure_generator))))
            reducable.append(tokenizer.convert_tokens_to_ids(list(map(str, tools.reciprocal(closure_generator)))))


    def __init__(self, **kwargs):
        # reciprocals: List[List[int]], one [generator_id, inverse_id] pair per generator,
        # e.g. the token ids of 'x' and 'X'
        self.reciprocals = kwargs.pop('reciprocals', None)

        # reducables: List[List[List[int]]], token ids of the normal closure generators
        # and their reciprocals, e.g. [[[x], [X]], [[y], [Y]], ...]
        self.reducables = kwargs.pop('reducables', None)

        super().__init__(**kwargs)
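

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# module): it assumes the tokenizer's vocabulary contains the generators as
# the string tokens str(x) / str(-x), which is what _from_tokenizer expects.
# The toy word-level tokenizer below is hypothetical and only demonstrates how
# GreedyConfig.from_tokenizer is meant to be called.
if __name__ == '__main__':
    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel
    from transformers import PreTrainedTokenizerFast

    vocab = {tok: i for i, tok in enumerate(['<pad>', '<eos>', '<unk>', '1', '-1', '2', '-2'])}
    backend = Tokenizer(WordLevel(vocab, unk_token='<unk>'))
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=backend, pad_token='<pad>', eos_token='<eos>',
    )

    config = GreedyConfig.from_tokenizer(freegroup_dimension=2, tokenizer=tokenizer)
    # reciprocals pairs each generator's token id with its inverse's token id.
    print(config.reciprocals)
    # reducables holds token ids of the closure generators and their reciprocals.
    print(config.reducables)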