fizban99
commited on
Commit
·
eaee63c
1
Parent(s):
2da89ac
reranking added
Browse files- .gitignore +2 -0
- app.py +10 -2
- simiandb.py +2 -2
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
*.pyc
|
app.py
CHANGED
|
@@ -7,17 +7,25 @@ Created on Wed Mar 22 19:59:54 2023
|
|
| 7 |
import gradio as gr
|
| 8 |
from simiandb import Simiandb
|
| 9 |
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
|
| 13 |
|
| 14 |
model_name = "all-MiniLM-L6-v2"
|
| 15 |
hf = HuggingFaceEmbeddings(model_name=model_name)
|
|
|
|
| 16 |
|
| 17 |
documentdb = Simiandb("mystore", embedding_function=hf, mode="a")
|
| 18 |
|
| 19 |
def search(query):
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
iface = gr.Interface(fn=search, inputs="text", outputs="text")
|
| 23 |
-
iface.launch()
|
|
|
|
|
|
|
|
|
| 7 |
import gradio as gr
|
| 8 |
from simiandb import Simiandb
|
| 9 |
from langchain.embeddings import HuggingFaceEmbeddings
|
| 10 |
+
from sentence_transformers import CrossEncoder
|
| 11 |
|
| 12 |
|
| 13 |
|
| 14 |
|
| 15 |
model_name = "all-MiniLM-L6-v2"
|
| 16 |
hf = HuggingFaceEmbeddings(model_name=model_name)
|
| 17 |
+
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
| 18 |
|
| 19 |
documentdb = Simiandb("mystore", embedding_function=hf, mode="a")
|
| 20 |
|
| 21 |
def search(query):
|
| 22 |
+
hits = documentdb.similarity_search(query)
|
| 23 |
+
cross_inp = [[query, hit] for hit in hits]
|
| 24 |
+
cross_scores = cross_encoder.predict(cross_inp)
|
| 25 |
+
hits = [hit for _, hit in sorted(zip(cross_scores, hits), reverse=True)]
|
| 26 |
+
return hits[0]
|
| 27 |
|
| 28 |
iface = gr.Interface(fn=search, inputs="text", outputs="text")
|
| 29 |
+
iface.launch()
|
| 30 |
+
|
| 31 |
+
#print(search("what is the balloon boy hoax"))
|
simiandb.py
CHANGED
|
@@ -178,7 +178,7 @@ class Simiandb():
|
|
| 178 |
batch = self._vector_table.chunkshape[0]*25
|
| 179 |
res = np.ascontiguousarray(np.empty(shape=(count,), dtype="float32"))
|
| 180 |
end = 0
|
| 181 |
-
|
| 182 |
while end!=count:
|
| 183 |
end += batch
|
| 184 |
end = end if end <= count else count
|
|
@@ -189,7 +189,7 @@ class Simiandb():
|
|
| 189 |
|
| 190 |
indices = np.argpartition(res, -k)[-k:] #from https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
|
| 191 |
indices = indices[np.argsort(res[indices])[::-1]]
|
| 192 |
-
|
| 193 |
return indices
|
| 194 |
|
| 195 |
|
|
|
|
| 178 |
batch = self._vector_table.chunkshape[0]*25
|
| 179 |
res = np.ascontiguousarray(np.empty(shape=(count,), dtype="float32"))
|
| 180 |
end = 0
|
| 181 |
+
|
| 182 |
while end!=count:
|
| 183 |
end += batch
|
| 184 |
end = end if end <= count else count
|
|
|
|
| 189 |
|
| 190 |
indices = np.argpartition(res, -k)[-k:] #from https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
|
| 191 |
indices = indices[np.argsort(res[indices])[::-1]]
|
| 192 |
+
|
| 193 |
return indices
|
| 194 |
|
| 195 |
|