from transformers import PretrainedConfig, PreTrainedTokenizerBase
from freegroup import tools


class GreedyConfig(PretrainedConfig):

    @classmethod
    def from_tokenizer(cls, freegroup_dimension, tokenizer: PreTrainedTokenizerBase, **kwargs):
        config = cls(
            vocab_size=len(tokenizer),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            **kwargs
        )
        config._from_tokenizer(freegroup_dimension, tokenizer)
        return config

    def _from_tokenizer(self, freegroup_dimension, tokenizer):
        # Generators of the free group are labelled 1..n; their inverses are -1..-n.
        # Both are tokenized as the string form of the signed index.
        freegroup_generators = list(range(1, freegroup_dimension + 1))

        # reciprocals[i]: token ids of generator i + 1 and of its inverse.
        self.reciprocals = []
        for x in freegroup_generators:
            a, b = tokenizer.convert_tokens_to_ids([str(x), str(-x)])
            self.reciprocals.append([a, b])

        # reducables: one entry per normal closure, i.e. the closure of each single
        # generator plus the closure of all generators together. Each entry stores the
        # token ids of the closure's generators and of their reciprocals.
        self.reducables = [[] for _ in range(freegroup_dimension + 1)]
        closure_generators = [[x] for x in freegroup_generators] + [freegroup_generators[:]]
        for reducable, closure_generator in zip(self.reducables, closure_generators):
            reducable.append(tokenizer.convert_tokens_to_ids(list(map(str, closure_generator))))
            reducable.append(tokenizer.convert_tokens_to_ids(list(map(str, tools.reciprocal(closure_generator)))))

    def __init__(self, **kwargs):
        # reciprocals: List[List[int]], pairs of token ids of a generator and its inverse,
        # e.g. [[id('x'), id('X')], ...]
        self.reciprocals = kwargs.pop('reciprocals', None)
        # reducables: List[List[List[int]]], generators for normal closures,
        # e.g. [[[id('x')], [id('X')]], [[id('y')], [id('Y')]], ...]
        self.reducables = kwargs.pop('reducables', None)
        super().__init__(**kwargs)
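

# Usage sketch (not part of the original module): a minimal illustration of how
# GreedyConfig.from_tokenizer could be called, assuming only that the tokenizer's
# vocabulary contains one token per signed generator label ("1", "-1", ...) plus
# pad/eos tokens, as `_from_tokenizer` expects. The word-level vocabulary and the
# special tokens below are illustrative assumptions, not the project's tokenizer.
if __name__ == "__main__":
    from tokenizers import Tokenizer, models
    from transformers import PreTrainedTokenizerFast

    # Hypothetical vocabulary for a free group of rank 2.
    vocab = {"<pad>": 0, "<eos>": 1, "1": 2, "-1": 3, "2": 4, "-2": 5}
    backend = Tokenizer(models.WordLevel(vocab, unk_token="<pad>"))
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=backend, pad_token="<pad>", eos_token="<eos>"
    )

    config = GreedyConfig.from_tokenizer(freegroup_dimension=2, tokenizer=tokenizer)
    print(config.reciprocals)  # [[2, 3], [4, 5]]: (generator, inverse) token-id pairs
    print(config.reducables)   # token ids generating the closures of <1>, <2>, and the whole group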