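"""Streamlit app for semantic search over GIZ project descriptions.

Chunks the project texts, embeds them with BAAI/bge-m3 into a local Qdrant
collection, and retrieves the paragraphs most similar to a user query.
"""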
import streamlit as st
import pandas as pd
from langchain_text_splitters import TokenTextSplitter
from langchain.docstore.document import Document
from torch import cuda
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
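# Run the embedding model on GPU when one is available, otherwise fall back to CPU.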
device = 'cuda' if cuda.is_available() else 'cpu'
st.set_page_config(page_title="SEARCH IATI", layout='wide')
st.title("SEARCH IATI Database")
var = st.text_input("Enter keyword")
def create_chunks(text):
    """Split a text into token-based chunks."""
    text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=0)
    texts = text_splitter.split_text(text)
    return texts
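# Note on create_chunks: TokenTextSplitter counts tokens (via tiktoken) rather
# than characters, so the 500-token chunks track the embedding model's context
# budget directly.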
def get_chunks():
    """Load the GIZ website dump and split each project's text into chunks."""
    # orgas_df = pd.read_csv("iati_files/project_orgas.csv")
    # region_df = pd.read_csv("iati_files/project_region.csv")
    # sector_df = pd.read_csv("iati_files/project_sector.csv")
    # status_df = pd.read_csv("iati_files/project_status.csv")
    # texts_df = pd.read_csv("iati_files/project_texts.csv")
    # projects_df = pd.merge(orgas_df, region_df, on='iati_id', how='inner')
    # projects_df = pd.merge(projects_df, sector_df, on='iati_id', how='inner')
    # projects_df = pd.merge(projects_df, status_df, on='iati_id', how='inner')
    # projects_df = pd.merge(projects_df, texts_df, on='iati_id', how='inner')
    # giz_df = projects_df[projects_df.client.str.contains('bmz')].reset_index(drop=True)
    # giz_df.drop(columns=['orga_abbreviation', 'client',
    #                      'orga_full_name', 'country',
    #                      'country_flag', 'crs_5_code', 'crs_3_code',
    #                      'sgd_pred_code'], inplace=True)
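    # The IATI CSV pipeline above is disabled; the current source is a JSON dump
    # of the GIZ website.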
    giz_df = pd.read_json('iati_files/data_giz_website.json')
    giz_df = giz_df.rename(columns={'content': 'project_description'})
    giz_df['text_size'] = giz_df.apply(lambda x: len((x['project_name'] + x['project_description']).split()), axis=1)
    giz_df['chunks'] = giz_df.apply(lambda x: create_chunks(x['project_name'] + x['project_description']), axis=1)
    giz_df = giz_df.explode(column=['chunks'], ignore_index=True)
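    # After explode, each row holds exactly one chunk, so every chunk below
    # becomes its own Document carrying the project-level metadata.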
    placeholder = []
    for i in range(len(giz_df)):
        placeholder.append(Document(
            page_content=giz_df.loc[i, 'chunks'],
            metadata={
                "title_main": giz_df.loc[i, 'title_main'],
                "country_name": str(giz_df.loc[i, 'countries']),
                "client": giz_df.loc[i, 'client'],
                "language": giz_df.loc[i, 'language'],
                "political_sponsor": giz_df.loc[i, 'poli_trager'],
                "url": giz_df.loc[i, 'url'],
                # "iati_id": giz_df.loc[i, 'iati_id'],
                # "iati_orga_id": giz_df.loc[i, 'iati_orga_id'],
                # "crs_5_name": giz_df.loc[i, 'crs_5_name'],
                # "crs_3_name": giz_df.loc[i, 'crs_3_name'],
                # "sgd_pred_str": giz_df.loc[i, 'sgd_pred_str'],
                # "status": giz_df.loc[i, 'status'],
            }))
    return placeholder
def embed_chunks(chunks):
    """Embed the chunks with BGE-M3 and persist them in a local Qdrant collection."""
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3'
    )
    # placeholder for collections
    print("starting embedding")
    qdrant_collections = {}
    qdrant_collections['all'] = Qdrant.from_documents(
        chunks,
        embeddings,
        path="/data/local_qdrant",
        collection_name='all',
    )
    print(qdrant_collections)
    print("vector embeddings done")
    return qdrant_collections
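# embed_chunks() is a one-off indexing step and is not called in the app flow
# below; at runtime the app only reconnects to the persisted collection.
# @st.cache_resource keeps a single connection across Streamlit reruns, which
# matters because a path-based (local) Qdrant client locks its storage directory.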
@st.cache_resource
def get_local_qdrant():
    """Connect to the existing local Qdrant store once it has been created."""
    qdrant_collections = {}
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3')
    client = QdrantClient(path="/data/local_qdrant")
    print("Collections in local Qdrant:", client.get_collections())
    qdrant_collections['all'] = Qdrant(client=client, collection_name='all', embeddings=embeddings)
    return qdrant_collections
def get_context(vectorstore, query):
    """Retrieve the chunks whose similarity to the query clears the score threshold."""
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.5, "k": 10})
    # Optional re-ranking of the retrieved results (requires a model_config):
    # model = HuggingFaceCrossEncoder(model_name=model_config.get('ranker', 'MODEL'))
    # compressor = CrossEncoderReranker(model=model, top_n=int(model_config.get('ranker', 'TOP_K')))
    # compression_retriever = ContextualCompressionRetriever(
    #     base_compressor=compressor, base_retriever=retriever
    # )
    context_retrieved = retriever.invoke(query)
    print(f"retrieved paragraphs: {len(context_retrieved)}")
    return context_retrieved
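# A minimal sketch of enabling the commented-out re-ranking step above, with
# explicit values in place of the undefined model_config. The helper name and
# the cross-encoder model below are assumptions, not taken from this app.
def get_reranking_retriever(vectorstore, top_n=5):
    base_retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.5, "k": 10})
    # The cross-encoder scores each (query, chunk) pair jointly and keeps the top_n chunks.
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
    compressor = CrossEncoderReranker(model=model, top_n=top_n)
    return ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=base_retriever)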
# chunks = get_chunks()
vectorstores = get_local_qdrant()
vectorstore = vectorstores['all']
button = st.button("search")
if button:
    results = get_context(vectorstore, f"find the relevant paragraphs for: {var}")
    st.write(f"Found {len(results)} results for query: {var}")
    for i in results:
        st.subheader(i.metadata['title_main'])
        st.caption(f"Client: {i.metadata['client']}, Country: {i.metadata['country_name']}")
        st.write(i.page_content)
        st.divider()
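# One-time setup (hypothetical usage, not wired into this app): build the local
# index before first launch so get_local_qdrant() finds the 'all' collection:
#     embed_chunks(get_chunks())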