Spaces:
Runtime error
Runtime error
Update rpc.py
Browse files
rpc.py
CHANGED
|
@@ -149,16 +149,13 @@ class DiffAttention(keras.layers.Layer):
|
|
| 149 |
att = att1 - lambda_full * att2
|
| 150 |
|
| 151 |
outi = att @ v
|
| 152 |
-
attp = self.roll_embeddings(att, self.range_undo)
|
| 153 |
-
outp = attp @ pos_src
|
| 154 |
-
out = outi + outp
|
| 155 |
out = out * (1 - self.lambda_init)
|
| 156 |
return out
|
| 157 |
|
| 158 |
|
| 159 |
# Import Model
|
| 160 |
model = keras.models.load_model(
|
| 161 |
-
"
|
| 162 |
custom_objects={
|
| 163 |
"DiffAttention" : DiffAttention,
|
| 164 |
"SharedEmbedding" : SharedEmbedding,
|
|
@@ -187,13 +184,20 @@ def vectorize_texts(all_texts):
|
|
| 187 |
# Import Database and All Toks
|
| 188 |
index = None
|
| 189 |
all_toks = None
|
| 190 |
-
|
|
|
|
| 191 |
global index
|
| 192 |
global all_toks
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
with open(index_path + "/all_toks.json", "r") as f:
|
| 198 |
all_toks = json.loads(f.read())
|
| 199 |
|
|
@@ -209,8 +213,10 @@ def generate(text, use_rpc=True, max_tokens=128):
|
|
| 209 |
enc_text = enc_text[-input_size:]
|
| 210 |
if use_rpc:
|
| 211 |
xq = vectorize_texts([enc_text])[-1]
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
| 214 |
if all_toks[_id] in carry_toks:
|
| 215 |
tmp = tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0]
|
| 216 |
if tmp in enc_text:
|
|
|
|
| 149 |
att = att1 - lambda_full * att2
|
| 150 |
|
| 151 |
outi = att @ v
|
|
|
|
|
|
|
|
|
|
| 152 |
out = out * (1 - self.lambda_init)
|
| 153 |
return out
|
| 154 |
|
| 155 |
|
| 156 |
# Import Model
|
| 157 |
model = keras.models.load_model(
|
| 158 |
+
"rpc.keras",
|
| 159 |
custom_objects={
|
| 160 |
"DiffAttention" : DiffAttention,
|
| 161 |
"SharedEmbedding" : SharedEmbedding,
|
|
|
|
| 184 |
# Import Database and All Toks
|
| 185 |
index = None
|
| 186 |
all_toks = None
|
| 187 |
+
index_type = None


def load_index(index_path="/dev/shm/rpc-vecdb/index", idx_type="ngt"):
    """Load the vector index and the token list into module globals.

    Populates the module-level ``index``, ``all_toks`` and ``index_type``
    globals used by the retrieval path (see ``generate``).

    Args:
        index_path: Directory containing the index files and
            ``all_toks.json``. Defaults to the shared-memory path used
            by the deployment.
        idx_type: Index backend, either ``"ngt"`` or ``"faiss"``.

    Raises:
        ValueError: If ``idx_type`` is not a supported backend. Raised
            *before* any global state is mutated, so a failed call does
            not leave a stale ``index_type`` behind.
    """
    global index
    global all_toks
    global index_type

    # Backend imports are deliberately local: only the selected library
    # needs to be installed in the runtime environment.
    if idx_type == "ngt":
        import ngtpy
        new_index = ngtpy.Index(index_path, read_only=True)
    elif idx_type == "faiss":
        import faiss
        new_index = faiss.read_index(index_path + "/index.faiss")
    else:
        # Include the bad value so the caller can see what was passed.
        raise ValueError(f"Unknown index type: {idx_type!r}")

    index = new_index
    index_type = idx_type

    # id -> token-string mapping aligned with the index rows.
    with open(index_path + "/all_toks.json", "r", encoding="utf-8") as f:
        all_toks = json.load(f)
|
| 203 |
|
|
|
|
| 213 |
enc_text = enc_text[-input_size:]
|
| 214 |
if use_rpc:
|
| 215 |
xq = vectorize_texts([enc_text])[-1]
|
| 216 |
+
if index_type == "ngt":
|
| 217 |
+
_id = index.search(xq, size=1, epsilon=1)[0][0]
|
| 218 |
+
else:
|
| 219 |
+
_id = index.search(xq.reshape((1, -1)), 1)[1][0][0]
|
| 220 |
if all_toks[_id] in carry_toks:
|
| 221 |
tmp = tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0]
|
| 222 |
if tmp in enc_text:
|