Upload model.py with huggingface_hub
Browse files
model.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import random
|
4 |
+
|
5 |
+
class Vocab:
    """Vocabulary mapping edge strings ``"s->t"`` to integer index pairs.

    An edge is encoded as ``(i, j)`` where ``i`` identifies the source node
    (a row in the model's parameter list) and ``j`` enumerates that source's
    targets. ``edge_decode_dict`` is the inverse mapping.
    """

    def __init__(self, node_dict, nodeindex_dict, edge_dict, edge_decode_dict):
        self.node_dict = node_dict
        self.nodeindex_dict = nodeindex_dict
        self.edge_dict = edge_dict
        self.edge_decode_dict = edge_decode_dict

    def __call__(self, x):
        """Encode a single ``"s->t"`` string, or (recursively) a list of them."""
        if isinstance(x, list):
            return [self.__call__(_) for _ in x]
        else:
            return self.fetch(x)

    def fetch(self, x):
        """Encode one ``"s->t"`` edge string into its ``(i, j)`` index pair.

        Unknown source or target falls back to the ``"<unk>"->"<unk>"`` entry.
        NOTE(review): that entry only exists for vocabularies built by
        ``from_edge``; a ``from_node_dict`` vocab will raise KeyError on
        unknown edges unless ``"<unk>"`` was in the node dict — confirm intent.
        """
        s, t = x.split("->")
        return self.edge_dict[s][t] if s in self.edge_dict and t in self.edge_dict[s] else self.edge_dict["<unk>"]["<unk>"]

    @classmethod
    def from_node_dict(cls, dictname):
        """Build a fully-connected edge vocabulary from ``{node_name: index}``.

        Every ordered pair of nodes becomes an edge, encoded directly with the
        nodes' own indices, so for these vocabs ``j`` equals the target node's
        global index.
        """
        nodeindex_dict = dict()
        edge_dict = dict()
        edge_decode_dict = dict()
        for s in dictname:
            nodeindex_dict[dictname[s]] = s
            edge_dict[s] = {}
            for t in dictname:
                edge_dict[s][t] = (dictname[s], dictname[t])
                edge_decode_dict[(dictname[s], dictname[t])] = "->".join([s, t])
        return cls(None, nodeindex_dict, edge_dict, edge_decode_dict)

    @classmethod
    def from_edge(cls, filename):
        """Build a vocabulary from a file containing one ``s->t`` edge per line.

        Source rows are numbered in order of first appearance (row 0 is the
        ``<unk>`` fallback); targets are numbered in file order per source.
        """
        edge_dict = dict()
        edge_dict["<unk>"] = {}
        edge_dict["<unk>"]["<unk>"] = (0, 0)
        edge_decode_dict = dict()
        with open(filename) as f:
            for line in f:
                s, t = line.strip().split("->")
                if s not in edge_dict:
                    # New source: next free row, first target column.
                    i = len(edge_dict)
                    j = 0
                    edge_dict[s] = dict()
                else:
                    # Existing source: reuse its row, append a target column.
                    i = edge_dict[s][list(edge_dict[s].keys())[0]][0]
                    j = len(edge_dict[s])
                edge_dict[s][t] = (i, j)
                edge_decode_dict[(i, j)] = "->".join([s, t])
        # BUG FIX: __init__ takes four arguments; the original passed only
        # three, shifting edge_dict into the nodeindex_dict slot and
        # edge_decode_dict into the edge_dict slot. No node-name index is
        # built here, so nodeindex_dict is None.
        return cls(None, None, edge_dict, edge_decode_dict)

    def get_neighbor_of_edge(self, key, k):
        """Return up to ``k`` alternative edges sharing ``key``'s source
        (all of them, shuffled, if ``k == -1``)."""
        s, t = key.split("->")
        _s = s if s in self.edge_dict else "<unk>"
        ret = ["->".join([_s, _t]) for _t in self.edge_dict[_s].keys() if _t != t]
        random.shuffle(ret)
        return ret[:k] if k != -1 else ret

    def get_neighbor_of_node(self, key, k):
        """Return up to ``k`` outgoing edges of the node with index ``key``,
        excluding the self-loop (all of them, shuffled, if ``k == -1``)."""
        s = self.nodeindex_dict[key]
        ret = ["->".join([s, _t]) for _t in self.edge_dict[s].keys() if _t != s]
        random.shuffle(ret)
        return ret[:k] if k != -1 else ret

    def get_neighbor_of_edge_broadcast(self, key, edges, k=100):
        """For each edge in ``edges``, pair its source with the same shuffled
        sample of up to ``k`` alternative targets drawn from ``key``'s source."""
        s, t = key.split("->")
        _ret = [_t for _t in self.edge_dict[s].keys() if _t != t]
        random.shuffle(_ret)
        ret = []
        for edge in edges:
            s, t = edge.split("->")
            ret += [["->".join([s, _t]) for _t in _ret[:k]]]
        return ret

    @staticmethod
    def to_path(tokens):
        """Turn a node sequence into the list of consecutive ``"s->t"`` edges."""
        path = []
        for left, right in zip(tokens[:-1], tokens[1:]):
            path.append("->".join([left, right]))
        return path

    def get_edge_of_node(self, key):
        """Return all ``(i, j)`` index pairs for edges leaving node ``key``."""
        return list(self.edge_dict[key].values())

    def decode(self, x):
        """Map an ``(i, j)`` index pair back to its ``"s->t"`` string."""
        return self.edge_decode_dict[x]
|
90 |
+
|
91 |
+
|
92 |
+
class BraLM(nn.Module):
    """Edge-transition language model.

    Each source node ``s`` owns a stack of per-target weight matrices
    (``network``) and biases (``bias``); a state "energy" vector is propagated
    through the matrix of each traversed edge, and generation scores candidate
    edges by the norm of the resulting state.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.network = nn.ParameterList()
        self.bias = nn.ParameterList()
        # NOTE(review): despite the name, this is a GELU activation, not a
        # sigmoid. Renaming would break attribute references and any saved
        # state, so the name is kept.
        self.sigmoid = nn.GELU()
        # Learned per-position mixing weights (softmaxed over positions).
        self.positions = nn.Parameter(torch.ones(1, 512, 1))
        self.device = None

    def prepare_network(self, vocab):
        """Allocate one (num_targets, H, H) weight stack and one
        (num_targets, 1, H) bias stack per source node in ``vocab``."""
        for s in vocab.edge_dict:
            self.network.append(nn.Parameter(torch.randn(len(vocab.edge_dict[s]), self.hidden_size, self.hidden_size).uniform_(-0.5, 0.5)))
            self.bias.append(nn.Parameter(torch.randn(len(vocab.edge_dict[s]), 1, self.hidden_size).uniform_(-0.5, 0.5)))

    def _network(self, x, y):
        # Weight matrix of the edge with source row x and target column y.
        return self.network[x][y]

    def to_device(self, device):
        """Move all parameters to ``device`` and remember it for later
        tensor construction."""
        self.network.to(device)
        # BUG FIX: the biases were never moved, so every access in decode()
        # paid a host-to-device copy via .to(self.device). Move them once.
        self.bias.to(device)
        self.positions.data = self.positions.data.to(device)
        self.device = device

    @staticmethod
    def _reshape12(x):
        """Collapse all leading dimensions into one: (..., A, B) -> (-1, A, B)."""
        return x.reshape(-1, x.size(-2), x.size(-1))

    def get_positional_encoding(self, seq_len, d_model):
        """Return a (1, seq_len, d_model) sinusoidal positional encoding.

        NOTE(review): ``div_term`` here *multiplies* by 10000^(2i/d) instead
        of dividing as in the standard Transformer encoding — possibly
        intentional, but confirm against training code before changing.
        """
        position = torch.arange(0, seq_len).reshape(-1, 1)
        div_term = 10000.0 ** (torch.arange(0, d_model, 2) / d_model)
        position_encoding = torch.zeros(seq_len, d_model)
        position_encoding[:, 0::2] = torch.sin(position * div_term)
        position_encoding[:, 1::2] = torch.cos(position * div_term)
        return position_encoding.unsqueeze(0).to(self.device)

    def get_initial_tensor(self, batch_size):
        """Uniform initial state: every component is 1 / hidden_size."""
        energy_tensor = torch.ones(batch_size, 1, self.hidden_size) / self.hidden_size
        return energy_tensor.to(self.device)

    def decode(self, start, vocab, max_new_tokens=16, do_sample=False, temperature=1):
        """Greedy or sampled continuation of the edge sequence ``start``.

        ``start`` is a list of encoded (i, j) edge pairs; returns it extended
        with up to ``max_new_tokens`` generated pairs.
        """
        ret = []
        pe = self.get_positional_encoding(512, self.hidden_size)
        # --- consume the prompt, building a cache of per-step states ---
        for i, pair in enumerate(start):
            if i == 0:
                energy_tensor = self.get_initial_tensor(batch_size=1).squeeze(0)
            else:
                # Mix all cached states with learned positional weights.
                energy_tensor = (energy_cache * self.positions[:, :i, :].softmax(1)).sum(1, keepdim=True).squeeze(0)
            w = self._network(pair[0], pair[1]).to(self.device)
            b = self.bias[pair[0]][pair[1]].to(self.device)

            energy_tensor = self.sigmoid(energy_tensor.mm(w) + b + pe.squeeze(0)[i])
            if i == 0:
                energy_cache = energy_tensor
            else:
                energy_cache = torch.cat([energy_cache, energy_tensor], dim=0)
            ret += [pair]
            # NOTE(review): pair[1] is the target *column* index; using it as
            # a node index is only valid for from_node_dict vocabularies,
            # where the two coincide — confirm for from_edge vocabs.
            x = pair[1]
        prev_i = len(start)

        # --- autoregressive generation ---
        for i in range(max_new_tokens):
            candidates = vocab(vocab.get_neighbor_of_node(x, -1))
            all_w = torch.cat([self._network(z[0], z[1]).unsqueeze(0) for z in candidates], dim=0).to(self.device)
            all_b = torch.cat([self.bias[z[0]][z[1]].unsqueeze(0) for z in candidates], dim=0).to(self.device)

            curr_i = prev_i + i
            energy_tensor = (energy_cache * self.positions.squeeze(0)[:curr_i, :].softmax(0)).sum(0, keepdim=True)
            expand_energy_tensor = energy_tensor.unsqueeze(0).repeat(all_w.size(0), 1, 1)

            # NOTE(review): pe[:, i] restarts the positional index at 0 for
            # generated tokens instead of continuing at curr_i — confirm
            # whether this matches training before changing.
            nxt_energy_tensor = self.sigmoid(expand_energy_tensor.bmm(all_w) + all_b + pe[:, i])

            # Score each candidate edge by the Frobenius norm of its state.
            energy = nxt_energy_tensor.norm(2, (-2, -1))

            # BUG FIX: the original divided the *probabilities* by the
            # temperature, which changes neither argmax nor multinomial
            # sampling (multinomial renormalizes). Temperature must scale
            # the logits before the softmax to have any effect.
            if temperature > 0:
                probs = torch.softmax(energy / temperature, dim=-1)
            else:
                probs = torch.softmax(energy, dim=-1)
            if do_sample:
                index = torch.multinomial(probs, 1).item()
            else:
                index = probs.argmax(-1).item()

            y = candidates[index][-1]
            ret += [(x, y)]

            energy_tensor = nxt_energy_tensor[index, :, :]
            x = y

            energy_cache = torch.cat([energy_cache, energy_tensor], dim=0)

        return ret
|