import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_greedy import GreedyConfig
from freegroup import tools


class GreedyModel(PreTrainedModel):
    """Rule-based causal LM whose logits are computed from reduction stacks.

    No weights are learned.  For every sequence in the batch the model keeps
    one stack per entry of ``config.reducables``.  Each input token is pushed
    onto every stack and reducible suffixes are removed; the next-token
    logits are then derived from the tops of the stacks:

    * +1 for every token that continues a word of the stack's group (or
      starts it, when the top of the stack is not in that word);
    * +1 for the cyclic successor of the stack top inside each entry of
      ``config.reciprocals``;
    * ``-inf`` for the cyclic successor of the *current* token inside each
      entry of ``config.reciprocals``;
    * ``+inf`` for ``config.eos_token_id`` once all stacks are empty.
    """

    config_class = GreedyConfig

    def __init__(self, config: GreedyConfig):
        super().__init__(config)
        # Single scalar parameter; nothing in this class reads it.
        # NOTE(review): presumably present so the model is not parameter-free
        # for the base-class machinery -- confirm.
        self.stub = torch.nn.parameter.Parameter(torch.tensor(0.))

    @staticmethod
    def _cyclic_successor(word, token):
        """Return the element after the first ``token`` in ``word``, wrapping around."""
        return word[(word.index(token) + 1) % len(word)]

    def _reduce_step(self, token, stack, reducables):
        """Push ``token`` onto ``stack`` and drop reducible suffixes, in place.

        Args:
            token: 0-d tensor holding the token id (``.item()`` is called on it).
            stack: list of token ids; mutated in place.
            reducables: list of words checked in addition to
                ``config.reciprocals``.

        Returns:
            The same ``stack`` object, after mutation.
        """
        stack.append(token.item())
        for reducable in self.config.reciprocals + reducables:
            n = len(reducable)
            if len(stack) < n:
                continue
            # Doubling the word lets a suffix match any cyclic rotation of it.
            # NOTE(review): assumes tools.occurs(sub, word) tests whether
            # `sub` is a contiguous subword of `word` -- confirm.
            if tools.occurs(stack[-n:], reducable * 2):
                del stack[-n:]
        return stack

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the ``forward`` kwargs: the input ids plus the optional ``past``."""
        past = kwargs.pop('past', None)
        return {'input_ids': input_ids, 'past': past}

    def forward(self, input_ids=None, past=None, **kwargs):
        """Compute next-token logits for every position not yet covered by ``past``.

        Args:
            input_ids: 2-D integer tensor, unpacked as (batch, sequence).
            past: ``None`` or the ``(stacks, hidden_states)`` pair returned as
                ``past_key_values`` by a previous call.  ``stacks`` is
                mutated in place.
            **kwargs: accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` whose ``logits`` are
            (batch, time, vocab) and whose ``past_key_values`` is
            ``(stacks, hidden_states)`` with ``hidden_states`` laid out as
            (time, batch, vocab).

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        # Explicit raise instead of `assert`, which is stripped under -O.
        if input_ids is None:
            raise ValueError("Can't be None")

        config = self.config
        batch_size, sequence_length = input_ids.shape

        if past is None:
            # One independent stack per reducable group, per batch element.
            stacks = [[[] for _ in config.reducables] for _ in range(batch_size)]
            hidden_states = None
        else:
            stacks, hidden_states = past

        # Positions already present in the cache are not recomputed.
        begin_idx = 0 if hidden_states is None else hidden_states.size(0)

        for t in range(begin_idx, sequence_length):
            # Allocate on the inputs' device so the logits can be combined
            # with `input_ids` without a device mismatch.
            step_logits = torch.zeros(
                (batch_size, config.vocab_size), device=input_ids.device
            )

            for batch_idx, word in enumerate(input_ids):
                scores = step_logits[batch_idx]  # view: writes land in step_logits
                token = word[t]
                current = token.item()

                for stack, reducables in zip(stacks[batch_idx], config.reducables):
                    self._reduce_step(token, stack, reducables)
                    if not stack:
                        continue

                    last = stack[-1]
                    for r in reducables:
                        if last in r:
                            # Reward continuing the word after the stack top.
                            scores[self._cyclic_successor(r, last)] += 1
                        else:
                            # Otherwise reward starting the word.
                            scores[r[0]] += 1
                    for r in config.reciprocals:
                        if last in r:
                            scores[self._cyclic_successor(r, last)] += 1

                # Forbid the cyclic successor of the token just consumed in
                # every reciprocal entry; this overrides any reward above.
                for r in config.reciprocals:
                    if current in r:
                        scores[self._cyclic_successor(r, current)] = -torch.inf

                # All stacks empty: force end-of-sequence.
                if not any(stacks[batch_idx]):
                    scores[config.eos_token_id] = torch.inf

            if hidden_states is None:
                hidden_states = step_logits.clone().unsqueeze(0)
            else:
                hidden_states = torch.cat((hidden_states, step_logits.unsqueeze(0)))

        return CausalLMOutputWithPast(
            logits=hidden_states.permute(1, 0, 2),
            past_key_values=(stacks, hidden_states),
        )