Spaces:
Runtime error
Runtime error
Update rpc.py
Browse files
rpc.py
CHANGED
@@ -149,16 +149,13 @@ class DiffAttention(keras.layers.Layer):
|
|
149 |
att = att1 - lambda_full * att2
|
150 |
|
151 |
outi = att @ v
|
152 |
-
attp = self.roll_embeddings(att, self.range_undo)
|
153 |
-
outp = attp @ pos_src
|
154 |
-
out = outi + outp
|
155 |
out = out * (1 - self.lambda_init)
|
156 |
return out
|
157 |
|
158 |
|
159 |
# Import Model
|
160 |
model = keras.models.load_model(
|
161 |
-
"
|
162 |
custom_objects={
|
163 |
"DiffAttention" : DiffAttention,
|
164 |
"SharedEmbedding" : SharedEmbedding,
|
@@ -187,13 +184,20 @@ def vectorize_texts(all_texts):
|
|
187 |
# Import Database and All Toks
|
188 |
index = None
|
189 |
all_toks = None
|
190 |
-
|
|
|
191 |
global index
|
192 |
global all_toks
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
with open(index_path + "/all_toks.json", "r") as f:
|
198 |
all_toks = json.loads(f.read())
|
199 |
|
@@ -209,8 +213,10 @@ def generate(text, use_rpc=True, max_tokens=128):
|
|
209 |
enc_text = enc_text[-input_size:]
|
210 |
if use_rpc:
|
211 |
xq = vectorize_texts([enc_text])[-1]
|
212 |
-
|
213 |
-
|
|
|
|
|
214 |
if all_toks[_id] in carry_toks:
|
215 |
tmp = tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0]
|
216 |
if tmp in enc_text:
|
|
|
149 |
att = att1 - lambda_full * att2
|
150 |
|
151 |
outi = att @ v
|
|
|
|
|
|
|
152 |
out = out * (1 - self.lambda_init)
|
153 |
return out
|
154 |
|
155 |
|
156 |
# Import Model
|
157 |
model = keras.models.load_model(
|
158 |
+
"rpc.keras",
|
159 |
custom_objects={
|
160 |
"DiffAttention" : DiffAttention,
|
161 |
"SharedEmbedding" : SharedEmbedding,
|
|
|
184 |
# Import Database and All Toks
|
185 |
index = None
|
186 |
all_toks = None
|
187 |
+
index_type = None
|
188 |
+
def load_index(index_path="/dev/shm/rpc-vecdb/index", idx_type="ngt"):
|
189 |
global index
|
190 |
global all_toks
|
191 |
+
global index_type
|
192 |
+
index_type = idx_type
|
193 |
+
if idx_type == "ngt":
|
194 |
+
import ngtpy
|
195 |
+
index = ngtpy.Index(index_path, read_only=True)
|
196 |
+
elif idx_type == "faiss":
|
197 |
+
import faiss
|
198 |
+
index = faiss.read_index(index_path + "/index.faiss")
|
199 |
+
else:
|
200 |
+
raise ValueError("Unknown index type")
|
201 |
with open(index_path + "/all_toks.json", "r") as f:
|
202 |
all_toks = json.loads(f.read())
|
203 |
|
|
|
213 |
enc_text = enc_text[-input_size:]
|
214 |
if use_rpc:
|
215 |
xq = vectorize_texts([enc_text])[-1]
|
216 |
+
if index_type == "ngt":
|
217 |
+
_id = index.search(xq, size=1, epsilon=1)[0][0]
|
218 |
+
else:
|
219 |
+
_id = index.search(xq.reshape((1, -1)), 1)[1][0][0]
|
220 |
if all_toks[_id] in carry_toks:
|
221 |
tmp = tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0]
|
222 |
if tmp in enc_text:
|