greedy-intersection / modeling_greedy.py
import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_greedy import GreedyConfig
from freegroup import tools
class GreedyModel(PreTrainedModel):
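    """Rule-based causal LM: maintains one reduction stack per family of
    reducable words and derives next-token logits symbolically, by counting
    greedy-completion votes, instead of learning them."""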
config_class = GreedyConfig
def __init__(self, config: GreedyConfig):
super().__init__(config)
        # dummy parameter: the model has no learned weights, but utilities such
        # as `model.device` inspect the first parameter
        self.stub = torch.nn.parameter.Parameter(torch.tensor(0.))
def _reduce_step(self, token, stack, reducables):
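        # Example (assuming the pair [1, 2] is among self.config.reciprocals):
        # pushing 1 and then 2 gives stack [] -> [1] -> [1, 2] -> [], because
        # the suffix [1, 2] occurs in the doubled pair [1, 2, 1, 2].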
stack.append(token.item())
        # cancel the last n tokens whenever they occur in the doubled word,
        # i.e. whenever they form a cyclic shift of a reciprocal pair or of a
        # reducable word
        for reducable in self.config.reciprocals + reducables:
            n = len(reducable)
            if len(stack) >= n and tools.occurs(stack[-n:], reducable * 2):
                del stack[-n:]
return stack
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # depending on the transformers version, generate() hands the cache
        # back as either `past` or `past_key_values`; accept both
        past = kwargs.pop('past', None)
        if past is None:
            past = kwargs.pop('past_key_values', None)
        return {'input_ids': input_ids, 'past': past}
    def forward(self, input_ids = None, past = None, **kwargs):
        assert input_ids is not None, "input_ids must be provided"
        batch_size, sequence_length = input_ids.shape
        if past is None:
            # one reduction stack per family of reducable words, per batch element
            stacks = [[[] for _ in range(len(self.config.reducables))] for _ in range(batch_size)]
            hidden_states = None
        else:
            stacks, hidden_states = past
        # hidden_states caches the (T, batch, vocab) scores computed so far, so
        # only tokens that arrived after the cached prefix need processing
        begin_idx = 0 if hidden_states is None else hidden_states.size(0)
        for t in range(begin_idx, sequence_length):
            # this step's vote counts over the vocabulary, used directly as logits
            last_hidden_states = torch.zeros((batch_size, self.config.vocab_size), device=input_ids.device)
            for batch_idx, word in enumerate(input_ids):
                for stack, reducables in zip(stacks[batch_idx], self.config.reducables):
                    self._reduce_step(word[t], stack, reducables)
                    if not stack:
                        continue
                    last = stack[-1]
                    for r in reducables:
                        if last in r:
                            # vote for the generator that follows `last` in the
                            # cyclic order of this reducable word
                            pos = r.index(last)
                            key = r[(pos + 1) % len(r)]
                        else:
                            # `last` does not occur in r: vote for its first generator
                            key = r[0]
                        last_hidden_states[batch_idx][key] += 1
                    for r in self.config.reciprocals:
                        if last in r:
                            # also vote for the token that would cancel the stack top
                            pos = r.index(last)
                            key = r[(pos + 1) % len(r)]
                            last_hidden_states[batch_idx][key] += 1
                for r in self.config.reciprocals:
                    if word[t] in r:
                        # forbid emitting the inverse of the token just consumed,
                        # which would immediately cancel it
                        pos = r.index(word[t])
                        key = r[(pos + 1) % len(r)]
                        last_hidden_states[batch_idx][key] = -torch.inf
                if all(not stack for stack in stacks[batch_idx]):
                    # every stack is fully reduced: force end-of-sequence
                    last_hidden_states[batch_idx][self.config.eos_token_id] = torch.inf
            # append this step's scores to the running (T, batch, vocab) tensor
            if hidden_states is None:
                hidden_states = last_hidden_states.clone().unsqueeze(0)
            else:
                hidden_states = torch.cat((hidden_states, last_hidden_states.unsqueeze(0)))
        return CausalLMOutputWithPast(
            # reorder to (batch, seq, vocab), the layout generate() expects
            logits = hidden_states.permute(1, 0, 2),
            past_key_values = (stacks, hidden_states),
        )
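
# Usage sketch (hypothetical values; GreedyConfig lives in
# configuration_greedy.py and is assumed to accept these keyword arguments in
# the usual PretrainedConfig fashion):
#
#     config = GreedyConfig(
#         vocab_size = 5, eos_token_id = 0,
#         reciprocals = [[1, 2], [3, 4]],     # token 2 cancels 1, token 4 cancels 3
#         reducables = [[[1, 3]], [[2, 4]]],  # one stack is kept per inner family
#     )
#     model = GreedyModel(config)
#     out = model(torch.tensor([[1, 3, 4]]))
#     out.logits.shape                        # -> torch.Size([1, 3, 5])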