pedrocas15 committed on
Commit
fd69507
·
verified ·
1 Parent(s): 2a7c88e

Update rpc.py

Browse files
Files changed (1) hide show
  1. rpc.py +17 -11
rpc.py CHANGED
@@ -149,16 +149,13 @@ class DiffAttention(keras.layers.Layer):
149
  att = att1 - lambda_full * att2
150
 
151
  outi = att @ v
152
- attp = self.roll_embeddings(att, self.range_undo)
153
- outp = attp @ pos_src
154
- out = outi + outp
155
  out = out * (1 - self.lambda_init)
156
  return out
157
 
158
 
159
  # Import Model
160
  model = keras.models.load_model(
161
- "rpc_diff_12b_320inp_ct4_01w10.keras",
162
  custom_objects={
163
  "DiffAttention" : DiffAttention,
164
  "SharedEmbedding" : SharedEmbedding,
@@ -187,13 +184,20 @@ def vectorize_texts(all_texts):
187
  # Import Database and All Toks
188
  index = None
189
  all_toks = None
190
- def load_index(index_path="/dev/shm/rpc-vecdb/index"):
 
191
  global index
192
  global all_toks
193
- import ngtpy
194
- index = ngtpy.Index(index_path, read_only=True)
195
- #import faiss
196
- #index = faiss.read_index(index_path + "/index.faiss")
 
 
 
 
 
 
197
  with open(index_path + "/all_toks.json", "r") as f:
198
  all_toks = json.loads(f.read())
199
 
@@ -209,8 +213,10 @@ def generate(text, use_rpc=True, max_tokens=128):
209
  enc_text = enc_text[-input_size:]
210
  if use_rpc:
211
  xq = vectorize_texts([enc_text])[-1]
212
- _id = index.search(xq, size=1, epsilon=1)[0][0]
213
- #_id = index.search(xq.reshape((1, -1)), 1)[1][0][0]
 
 
214
  if all_toks[_id] in carry_toks:
215
  tmp = tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0]
216
  if tmp in enc_text:
 
149
  att = att1 - lambda_full * att2
150
 
151
  outi = att @ v
 
 
 
152
  out = out * (1 - self.lambda_init)
153
  return out
154
 
155
 
156
  # Import Model
157
  model = keras.models.load_model(
158
+ "rpc.keras",
159
  custom_objects={
160
  "DiffAttention" : DiffAttention,
161
  "SharedEmbedding" : SharedEmbedding,
 
184
  # Import Database and All Toks
185
  index = None
186
  all_toks = None
187
+ index_type = None
188
+ def load_index(index_path="/dev/shm/rpc-vecdb/index", idx_type="ngt"):
189
  global index
190
  global all_toks
191
+ global index_type
192
+ index_type = idx_type
193
+ if idx_type == "ngt":
194
+ import ngtpy
195
+ index = ngtpy.Index(index_path, read_only=True)
196
+ elif idx_type == "faiss":
197
+ import faiss
198
+ index = faiss.read_index(index_path + "/index.faiss")
199
+ else:
200
+ raise ValueError("Unknown index type")
201
  with open(index_path + "/all_toks.json", "r") as f:
202
  all_toks = json.loads(f.read())
203
 
 
213
  enc_text = enc_text[-input_size:]
214
  if use_rpc:
215
  xq = vectorize_texts([enc_text])[-1]
216
+ if index_type == "ngt":
217
+ _id = index.search(xq, size=1, epsilon=1)[0][0]
218
+ else:
219
+ _id = index.search(xq.reshape((1, -1)), 1)[1][0][0]
220
  if all_toks[_id] in carry_toks:
221
  tmp = tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0]
222
  if tmp in enc_text: