BriLLM commited on
Commit
7d2abd5
·
verified ·
1 Parent(s): bec855d

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +182 -0
model.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import random
4
+
5
+ class Vocab:
6
+ def __init__(self, node_dict, nodeindex_dict, edge_dict, edge_decode_dict):
7
+ self.node_dict = node_dict
8
+ self.nodeindex_dict = nodeindex_dict
9
+ self.edge_dict = edge_dict
10
+ self.edge_decode_dict = edge_decode_dict
11
+
12
+ def __call__(self, x):
13
+ if isinstance(x, list):
14
+ return [self.__call__(_) for _ in x]
15
+ else:
16
+ return self.fetch(x)
17
+
18
+ def fetch(self, x):
19
+ s, t = x.split("->")
20
+ return self.edge_dict[s][t] if s in self.edge_dict and t in self.edge_dict[s] else self.edge_dict["<unk>"]["<unk>"]
21
+
22
+ @classmethod
23
+ def from_node_dict(cls, dictname):
24
+ nodeindex_dict = dict()
25
+ edge_dict = dict()
26
+ edge_decode_dict = dict()
27
+ for s in dictname:
28
+ nodeindex_dict[dictname[s]] = s
29
+ edge_dict[s] = {}
30
+ for t in dictname:
31
+ edge_dict[s][t] = (dictname[s], dictname[t])
32
+ edge_decode_dict[(dictname[s], dictname[t])] = "->".join([s, t])
33
+ return cls(None, nodeindex_dict, edge_dict, edge_decode_dict)
34
+
35
+ @classmethod
36
+ def from_edge(cls, filename):
37
+ edge_dict = dict()
38
+ edge_dict["<unk>"] = {}
39
+ edge_dict["<unk>"]["<unk>"] = (0, 0)
40
+ edge_decode_dict = dict()
41
+ with open(filename) as f:
42
+ for line in f:
43
+ s, t = line.strip().split("->")
44
+ if s not in edge_dict:
45
+ i = len(edge_dict)
46
+ j = 0
47
+ edge_dict[s] = dict()
48
+ else:
49
+ i = edge_dict[s][list(edge_dict[s].keys())[0]][0]
50
+ j = len(edge_dict[s])
51
+ edge_dict[s][t] = (i, j)
52
+ edge_decode_dict[(i, j)] = "->".join([s, t])
53
+ return cls(None, edge_dict, edge_decode_dict)
54
+
55
+ def get_neighbor_of_edge(self, key, k):
56
+ s, t = key.split("->")
57
+ _s = s if s in self.edge_dict else "<unk>"
58
+ ret = ["->".join([_s, _t]) for _t in self.edge_dict[_s].keys() if _t != t]
59
+ random.shuffle(ret)
60
+ return ret[:k] if k != -1 else ret
61
+
62
+ def get_neighbor_of_node(self, key, k):
63
+ s = self.nodeindex_dict[key]
64
+ ret = ["->".join([s, _t]) for _t in self.edge_dict[s].keys() if _t != s]
65
+ random.shuffle(ret)
66
+ return ret[:k] if k != -1 else ret
67
+
68
+ def get_neighbor_of_edge_broadcast(self, key, edges, k=100):
69
+ s, t = key.split("->")
70
+ _ret = [_t for _t in self.edge_dict[s].keys() if _t != t]
71
+ random.shuffle(_ret)
72
+ ret = []
73
+ for edge in edges:
74
+ s, t = edge.split("->")
75
+ ret += [["->".join([s, _t]) for _t in _ret[:k]]]
76
+ return ret
77
+
78
+ @staticmethod
79
+ def to_path(tokens):
80
+ path = []
81
+ for left, right in zip(tokens[:-1], tokens[1:]):
82
+ path.append("->".join([left, right]))
83
+ return path
84
+
85
+ def get_edge_of_node(self, key):
86
+ return list(self.edge_dict[key].values())
87
+
88
+ def decode(self, x):
89
+ return self.edge_decode_dict[x]
90
+
91
+
92
+ class BraLM(nn.Module):
93
+ def __init__(self, hidden_size):
94
+ super().__init__()
95
+ self.hidden_size = hidden_size
96
+ self.network = nn.ParameterList()
97
+ self.bias = nn.ParameterList()
98
+ self.sigmoid = nn.GELU()
99
+ self.positions = nn.Parameter(torch.ones(1, 512, 1))
100
+ self.device = None
101
+
102
+ def prepare_network(self, vocab):
103
+ for s in vocab.edge_dict:
104
+ self.network.append(nn.Parameter(torch.randn(len(vocab.edge_dict[s]), self.hidden_size, self.hidden_size).uniform_(-0.5, 0.5)))
105
+ self.bias.append(nn.Parameter(torch.randn(len(vocab.edge_dict[s]), 1, self.hidden_size).uniform_(-0.5, 0.5)))
106
+
107
+ def _network(self, x, y):
108
+ return self.network[x][y]
109
+
110
+ def to_device(self, device):
111
+ self.network.to(device)
112
+ self.positions.data = self.positions.data.to(device)
113
+ self.device = device
114
+
115
+ @staticmethod
116
+ def _reshape12(x):
117
+ return x.reshape(-1, x.size(-2), x.size(-1))
118
+
119
+ def get_positional_encoding(self, seq_len, d_model):
120
+ position = torch.arange(0, seq_len).reshape(-1, 1)
121
+ div_term = 10000.0 ** (torch.arange(0, d_model, 2) / d_model)
122
+ position_encoding = torch.zeros(seq_len, d_model)
123
+ position_encoding[:, 0::2] = torch.sin(position * div_term)
124
+ position_encoding[:, 1::2] = torch.cos(position * div_term)
125
+ return position_encoding.unsqueeze(0).to(self.device)
126
+
127
+
128
+ def get_initial_tensor(self, batch_size):
129
+ energy_tensor = torch.ones(batch_size, 1, self.hidden_size) / self.hidden_size
130
+ return energy_tensor.to(self.device)
131
+
132
+
133
+ def decode(self, start, vocab, max_new_tokens=16, do_sample=False, temperature=1):
134
+ ret = []
135
+ pe = self.get_positional_encoding(512, self.hidden_size)
136
+ for i, pair in enumerate(start):
137
+ if i == 0:
138
+ energy_tensor = self.get_initial_tensor(batch_size=1).squeeze(0)
139
+ else:
140
+ energy_tensor = (energy_cache * self.positions[:, :i, :].softmax(1)).sum(1, keepdim=True).squeeze(0)
141
+ w = self._network(pair[0], pair[1]).to(self.device)
142
+ b = self.bias[pair[0]][pair[1]].to(self.device)
143
+
144
+ energy_tensor = self.sigmoid(energy_tensor.mm(w) + b + pe.squeeze(0)[i])
145
+ if i == 0:
146
+ energy_cache = energy_tensor
147
+ else:
148
+ energy_cache = torch.cat([energy_cache, energy_tensor], dim=0)
149
+ ret += [pair]
150
+ x = pair[1]
151
+ prev_i = len(start)
152
+
153
+ for i in range(max_new_tokens):
154
+ candidates = vocab(vocab.get_neighbor_of_node(x, -1))
155
+ all_w = torch.cat([self._network(z[0], z[1]).unsqueeze(0) for z in candidates], dim=0).to(self.device)
156
+ all_b = torch.cat([self.bias[z[0]][z[1]].unsqueeze(0) for z in candidates], dim=0).to(self.device)
157
+
158
+ curr_i = prev_i + i
159
+ energy_tensor = (energy_cache * self.positions.squeeze(0)[:curr_i, :].softmax(0)).sum(0, keepdim=True)
160
+ expand_energy_tensor = energy_tensor.unsqueeze(0).repeat(all_w.size(0), 1, 1)
161
+
162
+ nxt_energy_tensor = self.sigmoid(expand_energy_tensor.bmm(all_w)+all_b+pe[:,i])
163
+
164
+ energy = nxt_energy_tensor.norm(2, (-2,-1))
165
+
166
+ probs = torch.softmax(energy, dim=-1)
167
+ if temperature > 0:
168
+ probs = probs / temperature
169
+ if do_sample:
170
+ index = torch.multinomial(probs, 1).item()
171
+ else:
172
+ index = probs.argmax(-1).item()
173
+
174
+ y = candidates[index][-1]
175
+ ret += [(x, y)]
176
+
177
+ energy_tensor = nxt_energy_tensor[index, :, :]
178
+ x = y
179
+
180
+ energy_cache = torch.cat([energy_cache, energy_tensor], dim=0)
181
+
182
+ return ret