|
from transformers import PretrainedConfig, PreTrainedTokenizerBase |
|
from freegroup import tools |
|
|
|
class GreedyConfig(PretrainedConfig):
    """Configuration carrying token-id lookup tables for free-group generators.

    Besides whatever ``PretrainedConfig`` stores, an instance holds two
    attributes, both ``None`` until supplied as keyword arguments or filled
    in by :meth:`from_tokenizer`:

    * ``reciprocals`` -- one ``[id_of(str(x)), id_of(str(-x))]`` pair per
      generator ``x`` in ``1..freegroup_dimension``.
    * ``reducables`` -- one ``[ids_of(word), ids_of(tools.reciprocal(word))]``
      pair per "closure" word: each single-generator word ``[x]`` first,
      then the word made of all generators.
    """

    def __init__(self, **kwargs):
        """Pull the two custom tables out of ``kwargs`` and forward the rest.

        ``reciprocals`` and ``reducables`` are removed from ``kwargs``
        (defaulting to ``None``) before the remaining keyword arguments are
        handed to ``PretrainedConfig.__init__``.
        """
        self.reciprocals = kwargs.pop('reciprocals', None)
        self.reducables = kwargs.pop('reducables', None)
        super().__init__(**kwargs)

    @classmethod
    def from_tokenizer(cls, freegroup_dimension, tokenizer: PreTrainedTokenizerBase, **kwargs):
        """Build a config whose tables are derived from ``tokenizer``.

        ``vocab_size``, ``eos_token_id`` and ``pad_token_id`` are read from
        the tokenizer; any extra ``kwargs`` are passed straight to the
        constructor. The lookup tables are then populated for generators
        ``1..freegroup_dimension``.

        Returns the newly created config instance.
        """
        vocabulary_length = len(tokenizer)
        end_of_sequence_id = tokenizer.eos_token_id
        padding_id = tokenizer.pad_token_id

        instance = cls(
            vocab_size=vocabulary_length,
            eos_token_id=end_of_sequence_id,
            pad_token_id=padding_id,
            **kwargs
        )
        instance._from_tokenizer(freegroup_dimension, tokenizer)
        return instance

    def _from_tokenizer(self, freegroup_dimension, tokenizer):
        """Fill ``self.reciprocals`` and ``self.reducables`` from ``tokenizer``.

        Generators are the integers ``1..freegroup_dimension``; their tokens
        are the decimal strings of those integers (and of their negations).
        """
        generators = list(range(1, freegroup_dimension + 1))

        def ids_of(word):
            # Token ids for a word given as an iterable of signed integers.
            return tokenizer.convert_tokens_to_ids([str(letter) for letter in word])

        # One [id(x), id(-x)] pair per generator.
        pairs = []
        for generator in generators:
            positive_id, negative_id = tokenizer.convert_tokens_to_ids(
                [str(generator), str(-generator)]
            )
            pairs.append([positive_id, negative_id])
        self.reciprocals = pairs

        # Closure words: every single generator, then all generators together
        # (as a separate list object so the helper never sees `generators`
        # itself).
        closures = [[generator] for generator in generators]
        closures.append(list(generators))

        # Zipping against the range caps the number of entries at
        # freegroup_dimension + 1.
        # NOTE(review): tools.reciprocal presumably yields the inverse word --
        # confirm against the freegroup package.
        self.reducables = [
            [ids_of(closure), ids_of(tools.reciprocal(closure))]
            for _, closure in zip(range(freegroup_dimension + 1), closures)
        ]
|
|
|
|
|
|