greedy-intersection / configuration_greedy.py
kibrq's picture
Update model
623e9da
from transformers import PretrainedConfig, PreTrainedTokenizerBase
from freegroup import tools
class GreedyConfig(PretrainedConfig):
    """Configuration for the greedy free-group model.

    Besides the standard ``PretrainedConfig`` fields it stores two
    tokenizer-dependent tables of token ids:

    * ``reciprocals`` -- one ``[id(str(x)), id(str(-x))]`` pair for every
      generator ``x`` in ``1..freegroup_dimension``.
    * ``reducables`` -- ``freegroup_dimension + 1`` entries of the form
      ``[ids(word), ids(tools.reciprocal(word))]``.  The first
      ``freegroup_dimension`` entries use the single-generator words
      ``[1]``, ``[2]``, ...; the last one uses the word of all generators
      ``[1, 2, ..., freegroup_dimension]``.

    Both tables are ``None`` unless supplied to ``__init__`` or filled in
    by ``from_tokenizer``.
    """

    @classmethod
    def from_tokenizer(cls, freegroup_dimension, tokenizer: PreTrainedTokenizerBase, **kwargs):
        """Build a config whose vocabulary and token tables match *tokenizer*.

        Args:
            freegroup_dimension: number of free-group generators.  Generator
                ``x`` (1-based) must be tokenized as ``str(x)`` and its
                inverse as ``str(-x)``.
            tokenizer: source of ``vocab_size`` (``len(tokenizer)``),
                ``eos_token_id``, ``pad_token_id`` and the generator ids.
            **kwargs: forwarded unchanged to the constructor.

        Returns:
            A new config with ``reciprocals`` and ``reducables`` populated.

        Raises:
            ValueError: if a generator token is not in the tokenizer's
                vocabulary.  The tokenizer would otherwise map it silently
                to UNK or ``None``.
        """
        config = cls(
            vocab_size=len(tokenizer),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            **kwargs
        )
        config._from_tokenizer(freegroup_dimension, tokenizer)
        return config

    @staticmethod
    def _token_ids(tokenizer, tokens):
        """Convert *tokens* to ids, refusing to silently map unknown tokens.

        ``convert_tokens_to_ids`` returns the UNK id (or ``None`` when the
        tokenizer has no UNK token) for out-of-vocabulary tokens, which
        would corrupt the id tables without any error.
        """
        ids = tokenizer.convert_tokens_to_ids(tokens)
        # getattr: tolerate tokenizer-like objects without an UNK notion.
        unk_id = getattr(tokenizer, "unk_token_id", None)
        missing = [tok for tok, idx in zip(tokens, ids) if idx is None or idx == unk_id]
        if missing:
            raise ValueError(
                f"Tokens {missing} are not in the tokenizer vocabulary; "
                "cannot build GreedyConfig id tables."
            )
        return ids

    def _from_tokenizer(self, freegroup_dimension, tokenizer):
        """Fill ``reciprocals`` and ``reducables`` from *tokenizer*'s vocabulary."""
        generators = list(range(1, freegroup_dimension + 1))

        # One [id(x), id(-x)] pair per generator.
        self.reciprocals = [
            self._token_ids(tokenizer, [str(x), str(-x)]) for x in generators
        ]

        # Closure words: each single generator, then the word of all generators.
        closure_words = [[x] for x in generators] + [list(generators)]
        # Each entry holds the ids of the word and of tools.reciprocal(word).
        # NOTE(review): tools.reciprocal is external (freegroup); presumably it
        # returns the group inverse of the word -- confirm.
        self.reducables = [
            [
                self._token_ids(tokenizer, list(map(str, word))),
                self._token_ids(tokenizer, list(map(str, tools.reciprocal(word)))),
            ]
            for word in closure_words
        ]

    def __init__(self, **kwargs):
        # reciprocals: List[List[int]] -- one [id(x), id(X)] pair per
        # generator x (X is its inverse), e.g. [[id('1'), id('-1')], ...].
        self.reciprocals = kwargs.pop('reciprocals', None)
        # reducables: List[List[List[int]]] -- token-id words generating the
        # normal closures, e.g. [[[x], [X]], [[y], [Y]], ...].
        self.reducables = kwargs.pop('reducables', None)
        # Remaining kwargs (vocab_size, eos/pad ids, ...) go to PretrainedConfig.
        super().__init__(**kwargs)