Update app.py
app.py
CHANGED
@@ -6,13 +6,17 @@ from torch import cuda
 from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
 from langchain_community.vectorstores import Qdrant
 from qdrant_client import QdrantClient
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+
 device = 'cuda' if cuda.is_available() else 'cpu'
 
 
 st.set_page_config(page_title="SEARCH IATI",layout='wide')
 st.title("SEARCH IATI Database")
 var=st.text_input("enter keyword")
-
+
 
 def create_chunks(text):
     text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=0)
@@ -75,6 +79,41 @@ def embed_chunks(chunks):
     print("vector embeddings done")
     return qdrant_collections
 
-
-
-
+def get_local_qdrant():
+    """Once the local Qdrant server is created, this is used to connect to the existing server."""
+
+    qdrant_collections = {}
+    embeddings = HuggingFaceEmbeddings(
+        model_kwargs={'device': device},
+        encode_kwargs={'normalize_embeddings': True},
+        model_name='BAAI/bge-m3')
+    client = QdrantClient(path="/data/local_qdrant")
+    print("Collections in local Qdrant:", client.get_collections())
+    qdrant_collections['all'] = Qdrant(client=client, collection_name='all', embeddings=embeddings)
+    return qdrant_collections
+
+def get_context(vectorstore, query):
+    # create metadata filter
+
+
+    # getting context
+    retriever = vectorstore.as_retriever(search_type="similarity_score_threshold",
+                                         search_kwargs={"score_threshold": 0.5,
+                                                        "k": 10,
+                                                        })
+    # # re-ranking the retrieved results
+    # model = HuggingFaceCrossEncoder(model_name=model_config.get('ranker','MODEL'))
+    # compressor = CrossEncoderReranker(model=model, top_n=int(model_config.get('ranker','TOP_K')))
+    # compression_retriever = ContextualCompressionRetriever(
+    #     base_compressor=compressor, base_retriever=retriever
+    # )
+    context_retrieved = retriever.invoke(query)
+    print(f"retrieved paragraphs: {len(context_retrieved)}")
+
+    return context_retrieved
+#chunks = get_chunks()
+vectorstores = get_local_qdrant()
+button = st.button("search")
+results = get_context(vectorstores['all'], f"find the relevant paragraphs for: {var}")
+
+st.write(results)