import torch

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_greedy import GreedyConfig
from freegroup import tools
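
# Orientation comment (added): GreedyModel is not a trained network. forward()
# maintains one reduction stack per relator set and, at every position, hands
# out pseudo-logits that favour tokens which continue a partially matched
# relator or cancel against a reciprocal pair. The exact meaning of
# `reciprocals` and `reducables` is defined by GreedyConfig.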
class GreedyModel(PreTrainedModel):
    config_class = GreedyConfig

    def __init__(self, config: GreedyConfig):
        super().__init__(config)
        # Dummy parameter: the model is purely rule-based, but PreTrainedModel
        # expects at least one registered parameter (e.g. for device placement).
        self.stub = torch.nn.Parameter(torch.tensor(0.0))

    def _reduce_step(self, token, stack, reducables):
        # Push the new token, then delete any stack suffix that matches a
        # cyclic shift of a reciprocal pair or of one of the given relators
        # (a length-n word is a cyclic shift of `w` iff it occurs in `w * 2`).
        stack.append(token.item())

        for reducable in self.config.reciprocals + reducables:
            n = len(reducable)
            if len(stack) >= n and tools.occurs(stack[-n:], reducable * 2):
                del stack[-n:]

        return stack
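
    # A minimal sketch of the reduction, assuming token ids 1 and 2 form a
    # reciprocal pair in the config (hypothetical values, for illustration):
    #
    #   stack = [1]      # after consuming token 1
    #   stack = []       # token 2 arrives, suffix [1, 2] cancels away
    #
    # so a fully reduced word leaves every stack empty.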

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # generate() feeds the previous step's past_key_values back in through
        # kwargs; accept both the legacy 'past' and the current
        # 'past_key_values' names.
        past = kwargs.pop('past_key_values', kwargs.pop('past', None))
        return {'input_ids': input_ids, 'past': past}

    def forward(self, input_ids=None, past=None, **kwargs):
        assert input_ids is not None, "input_ids must not be None"

        batch_size, sequence_length = input_ids.shape

        if past is None:
            # One reduction stack per relator set, per batch element.
            stacks = [[[] for _ in range(len(self.config.reducables))] for _ in range(batch_size)]
            hidden_states = None
        else:
            stacks, hidden_states = past

        # Only score positions not seen yet; cached scores for earlier
        # positions live in hidden_states, of shape (seq, batch, vocab).
        begin_idx = 0 if hidden_states is None else hidden_states.size(0)

        for t in range(begin_idx, sequence_length):
            # Pseudo-logits for this position; higher means "propose this token".
            last_hidden_states = torch.zeros(
                (batch_size, self.config.vocab_size), device=input_ids.device
            )

            for batch_idx, word in enumerate(input_ids):
                for stack, reducables in zip(stacks[batch_idx], self.config.reducables):
                    self._reduce_step(word[t], stack, reducables)
                    if not stack:
                        continue

                    last = stack[-1]

                    # Vote for tokens that keep matching some relator: the
                    # token after `last` along the relator cycle, or the
                    # relator's first token if it is not in play yet.
                    for r in reducables:
                        if last in r:
                            pos = r.index(last)
                            key = r[(pos + 1) % len(r)]
                        else:
                            key = r[0]
                        last_hidden_states[batch_idx][key] += 1

                    # Also vote for the reciprocal of the last unmatched
                    # token, which would cancel it.
                    for r in self.config.reciprocals:
                        if last in r:
                            pos = r.index(last)
                            key = r[(pos + 1) % len(r)]
                            last_hidden_states[batch_idx][key] += 1

                # Forbid immediately undoing the token just consumed: its
                # reciprocal gets -inf regardless of the stack contents.
                for r in self.config.reciprocals:
                    if word[t] in r:
                        pos = r.index(word[t])
                        key = r[(pos + 1) % len(r)]
                        last_hidden_states[batch_idx][key] = -torch.inf

                # All stacks empty: the word is fully reduced, force EOS.
                if all(len(s) == 0 for s in stacks[batch_idx]):
                    last_hidden_states[batch_idx][self.config.eos_token_id] = torch.inf

            # Append this step's scores to the running (seq, batch, vocab) cache.
            if hidden_states is None:
                hidden_states = last_hidden_states.clone().unsqueeze(0)
            else:
                hidden_states = torch.cat((hidden_states, last_hidden_states.unsqueeze(0)))

        return CausalLMOutputWithPast(
            logits=hidden_states.permute(1, 0, 2),  # -> (batch, seq, vocab)
            past_key_values=(stacks, hidden_states),
        )
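

# Usage sketch (hypothetical config values, shown only for illustration --
# consult configuration_greedy.GreedyConfig for the real constructor):
#
#   config = GreedyConfig(...)          # vocab_size, eos_token_id,
#                                       # reciprocals, reducables, ...
#   model = GreedyModel(config)
#   out = model(input_ids=torch.tensor([[1, 3, 2]]))
#   out.logits.shape                    # (batch, seq, vocab_size)
#
#   # Greedy decoding round-trips (stacks, hidden_states) as past_key_values:
#   # model.generate(torch.tensor([[1]]), max_length=8, do_sample=False)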