import streamlit as st
import pandas as pd
from torch import cuda
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_qdrant import FastEmbedSparse, RetrievalMode
from appStore.prep_data import process_giz_worldwide
# get the device to be used, either GPU or CPU
device = 'cuda' if cuda.is_available() else 'cpu'

st.set_page_config(page_title="SEARCH IATI", layout='wide')
st.title("SEARCH IATI Database")
var = st.text_input("Enter a keyword")
def embed_chunks(chunks):
    """
    Takes the chunks and builds the embeddings for the list of chunks.
    Currently only dense embeddings are created; a hedged sketch of the
    intended hybrid (dense + sparse) variant follows this function.
    """
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3'
    )
    # sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")
    # placeholder dict for the vector-store collections
    print("starting embedding")
    qdrant_collections = {}
    qdrant_collections['all'] = Qdrant.from_documents(
        chunks,
        embeddings,
        path="/data/local_qdrant",
        collection_name='all',
    )
    print(qdrant_collections)
    print("vector embeddings done")
def get_local_qdrant():
    """Once the local Qdrant store has been created, this makes the connection to the existing store."""
    qdrant_collections = {}
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3')
    client = QdrantClient(path="/data/local_qdrant")
    print("Collections in local Qdrant:", client.get_collections())
    qdrant_collections['all'] = Qdrant(client=client, collection_name='all', embeddings=embeddings)
    return qdrant_collections

def get_context(vectorstore, query):
    # a metadata filter could be applied here; a hedged sketch follows below
    # retrieve the context passages
    retriever = vectorstore.as_retriever(search_type="similarity_score_threshold",
                                         search_kwargs={"score_threshold": 0.5,
                                                        "k": 10})
    # # re-ranking of the retrieved results (disabled: model_config is not defined
    # # in this file; a hedged working sketch follows this function)
    # model = HuggingFaceCrossEncoder(model_name=model_config.get('ranker', 'MODEL'))
    # compressor = CrossEncoderReranker(model=model, top_n=int(model_config.get('ranker', 'TOP_K')))
    # compression_retriever = ContextualCompressionRetriever(
    #     base_compressor=compressor, base_retriever=retriever
    # )
    context_retrieved = retriever.invoke(query)
    print(f"retrieved paragraphs: {len(context_retrieved)}")
    return context_retrieved
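
# A minimal sketch of the cross-encoder re-ranking that is commented out inside
# get_context, with assumed stand-ins for the undefined model_config: the model
# name 'cross-encoder/ms-marco-MiniLM-L-6-v2' and top_n=5 are illustrative only:
# def get_reranked_context(vectorstore, query):
#     retriever = vectorstore.as_retriever(
#         search_type="similarity_score_threshold",
#         search_kwargs={"score_threshold": 0.5, "k": 10},
#     )
#     model = HuggingFaceCrossEncoder(model_name='cross-encoder/ms-marco-MiniLM-L-6-v2')
#     compressor = CrossEncoderReranker(model=model, top_n=5)
#     compression_retriever = ContextualCompressionRetriever(
#         base_compressor=compressor, base_retriever=retriever
#     )
#     return compression_retriever.invoke(query)
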
# first we create the chunks for the IATI documents
chunks = process_giz_worldwide()
# preview the first five chunks
for i in range(5):
    print(chunks.loc[i, 'chunks'])
#print("chunking done")

# once the chunks are done, we perform the hybrid embeddings (one-time indexing step)
#embed_chunks(chunks)
vectorstores = get_local_qdrant()
vectorstore = vectorstores['all']

button = st.button("search")
if button:
    results = get_context(vectorstore, f"find the relevant paragraphs for: {var}")
    st.write(f"Found {len(results)} results for query: {var}")
    for doc in results:
        st.subheader(f"{doc.metadata['id']}: {doc.metadata['title_main']}")
        st.caption(f"Status: {doc.metadata['status']}, Country: {doc.metadata['country_name']}")
        st.write(doc.page_content)
        st.divider()