import torch

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_greedy import GreedyConfig
from freegroup import tools
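
# GreedyModel is a rule-based next-token scorer wrapped as a PreTrainedModel
# so that `generate()` can drive it. For each batch element it keeps one
# stack per relator family from the config, greedily cancelling relator and
# inverse-pair suffixes; its "logits" are heuristic votes, not the output of
# learned weights.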

class GreedyModel(PreTrainedModel):
    config_class = GreedyConfig

    def __init__(self, config: GreedyConfig):
        super().__init__(config)
        # Placeholder parameter: the model is rule-based, but PreTrainedModel
        # machinery (device placement, state dicts) expects at least one.
        self.stub = torch.nn.Parameter(torch.tensor(0.))

    def _reduce_step(self, token, stack, reducables):
        # Push the token, then cancel any stack suffix that matches a cyclic
        # rotation of a relator or an inverse pair (doubling the relator makes
        # every rotation appear as a contiguous subword).
        stack.append(token.item())

        for reducable in self.config.reciprocals + reducables:
            n = len(reducable)
            if len(stack) >= n and tools.occurs(stack[-n:], reducable * 2):
                del stack[-n:]

        return stack

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Thread the (stacks, hidden_states) cache through `generate()`.
        past = kwargs.pop('past', None)
        return {'input_ids': input_ids, 'past': past}

    def forward(self, input_ids=None, past=None, **kwargs):

        assert input_ids is not None, "input_ids must not be None"

        batch_size, sequence_length = input_ids.shape

        if past is None:
            # One stack per relator family, for every batch element.
            stacks = [[[] for _ in range(len(self.config.reducables))] for _ in range(batch_size)]
            hidden_states = None
        else:
            stacks, hidden_states = past

        # Resume after the timesteps that were already scored in `past`.
        begin_idx = 0 if hidden_states is None else hidden_states.size(0)

        for t in range(begin_idx, sequence_length):
            last_hidden_states = torch.zeros(
                (batch_size, self.config.vocab_size), device=input_ids.device
            )

            for batch_idx, word in enumerate(input_ids):
                for stack, reducables in zip(stacks[batch_idx], self.config.reducables):

                    self._reduce_step(word[t], stack, reducables)
                    if not stack:
                        continue

                    last = stack[-1]

                    # Vote for tokens extending an occurrence of a relator:
                    # if the stack top lies on the relator, favour its cyclic
                    # successor; otherwise favour the relator's first letter.
                    for r in reducables:
                        if last in r:
                            pos = r.index(last)
                            key = r[(pos + 1) % len(r)]
                        else:
                            key = r[0]
                        last_hidden_states[batch_idx][key] += 1

                    # Also vote for the inverse of the stack top, which would
                    # cancel it on the next step.
                    for r in self.config.reciprocals:
                        if last in r:
                            pos = r.index(last)
                            key = r[(pos + 1) % len(r)]
                            last_hidden_states[batch_idx][key] += 1

                # Forbid emitting the inverse of the token just consumed:
                # it would immediately undo this step.
                tok = word[t].item()
                for r in self.config.reciprocals:
                    if tok in r:
                        pos = r.index(tok)
                        key = r[(pos + 1) % len(r)]
                        last_hidden_states[batch_idx][key] = -torch.inf

                # Every stack empty means the word has fully cancelled under
                # every relator family, so force EOS.
                if not any(stacks[batch_idx]):
                    last_hidden_states[batch_idx][self.config.eos_token_id] = torch.inf

            if hidden_states is None:
                hidden_states = last_hidden_states.unsqueeze(0)
            else:
                hidden_states = torch.cat((hidden_states, last_hidden_states.unsqueeze(0)))

        # logits: (batch_size, sequence_length, vocab_size); the cache holds
        # the per-batch stacks plus all scores computed so far.
        return CausalLMOutputWithPast(
            logits=hidden_states.permute(1, 0, 2),
            past_key_values=(stacks, hidden_states),
        )
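
# Usage sketch. The concrete GreedyConfig values below are invented; only the
# field names (vocab_size, eos_token_id, reciprocals, reducables) come from
# the code above, and the relative import means this module runs as part of
# its package rather than as a script:
#
#     config = GreedyConfig(
#         vocab_size=5,
#         eos_token_id=0,
#         reciprocals=[[1, 2], [3, 4]],   # generator / inverse-token pairs
#         reducables=[[[1, 3]]],          # one relator family
#     )
#     model = GreedyModel(config)
#     out = model.generate(torch.tensor([[1, 3]]), max_length=8)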