File size: 5,479 Bytes
0130713
c2f0c5c
 
 
f5dac9b
 
bbc879f
 
3d3fc58
 
 
 
f5dac9b
0130713
 
 
 
b60ea35
3d3fc58
0130713
c2f0c5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5dac9b
 
 
 
 
 
 
 
865766a
f5dac9b
 
 
 
 
 
 
 
 
 
 
 
320a9e1
3d3fc58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
075510e
3d3fc58
 
 
 
 
 
 
 
 
 
 
47df8ba
 
3d3fc58
7d478f8
a4730ef
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import streamlit as st
import pandas as pd
from langchain_text_splitters import TokenTextSplitter
from langchain.docstore.document import Document
from torch import cuda
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# Run the embedding model on the GPU when one is available, otherwise on the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


# Page layout plus the single free-text box the user types the search query into.
st.set_page_config(page_title="SEARCH IATI", layout='wide')
st.title("SEARCH IATI Database")
var = st.text_input("enter keyword")


def create_chunks(text, chunk_size=500, chunk_overlap=0):
    """Split ``text`` into token-based chunks with ``TokenTextSplitter``.

    Args:
        text: the string to split.
        chunk_size: chunk size in tokens, as interpreted by ``TokenTextSplitter``
            (default 500, the value previously hard-coded).
        chunk_overlap: tokens shared between consecutive chunks (default 0).

    Returns:
        list[str]: the chunks returned by ``TokenTextSplitter.split_text``.
    """
    text_splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_text(text)

def _load_iati_projects():
    """Read the IATI CSV exports under ``iati_files/`` and inner-join them on ``iati_id``."""
    table_names = [
        "project_orgas",
        "project_region",
        "project_sector",
        "project_status",
        "project_texts",
    ]
    projects_df = pd.read_csv(f"iati_files/{table_names[0]}.csv")
    for name in table_names[1:]:
        table_df = pd.read_csv(f"iati_files/{name}.csv")
        projects_df = pd.merge(projects_df, table_df, on='iati_id', how='inner')
    return projects_df


def _chunk_to_document(row):
    """Wrap one exploded chunk row (a column -> value dict) as a ``Document`` with project metadata."""
    return Document(
        page_content=row['chunks'],
        metadata={
            "iati_id": row['iati_id'],
            "iati_orga_id": row['iati_orga_id'],
            "country_name": str(row['country_name']),
            "crs_5_name": row['crs_5_name'],
            "crs_3_name": row['crs_3_name'],
            "sgd_pred_str": row['sgd_pred_str'],
            "status": row['status'],
            "title_main": row['title_main'],
        },
    )


def get_chunks():
    """Build ``Document`` chunks for BMZ-client projects from the IATI CSV exports.

    Reads the five ``iati_files/project_*.csv`` tables, inner-joins them on
    ``iati_id``, keeps rows whose ``client`` contains ``'bmz'``, splits each
    project's title + description into token chunks via ``create_chunks`` and
    wraps every chunk in a ``Document`` carrying the project's metadata.

    Returns:
        list[Document]: one entry per non-empty chunk.
    """
    projects_df = _load_iati_projects()

    # na=False: a missing client would put NaN into the mask and make the
    # boolean indexing raise instead of simply excluding that row.
    is_bmz = projects_df['client'].str.contains('bmz', na=False)
    giz_df = projects_df[is_bmz].reset_index(drop=True)
    giz_df = giz_df.drop(columns=['orga_abbreviation', 'client',
                                  'orga_full_name', 'country',
                                  'country_flag', 'crs_5_code', 'crs_3_code',
                                  'sgd_pred_code'])

    # Empty CSV cells are read as NaN (a float), which cannot be concatenated
    # with str; treat them as empty text. Join with a space so the last word of
    # the title is not glued onto the first word of the description.
    full_text = (
        giz_df['title_main'].fillna('').astype(str)
        + ' '
        + giz_df['description_main'].fillna('').astype(str)
    ).str.strip()
    giz_df['chunks'] = full_text.apply(create_chunks)

    # explode() turns an empty chunk list into a NaN row; drop those so every
    # Document receives real text.
    giz_df = giz_df.explode(column='chunks', ignore_index=True).dropna(subset=['chunks'])

    return [_chunk_to_document(row) for row in giz_df.to_dict('records')]

def embed_chunks(chunks):
    """Embed ``chunks`` with BAAI/bge-m3 and store them in the on-disk Qdrant at ``/data/local_qdrant``.

    Args:
        chunks: the documents to index (as produced by ``get_chunks``).

    Returns:
        dict mapping the collection name ``'all'`` to the resulting ``Qdrant`` store.
    """
    embedding_model = HuggingFaceEmbeddings(
        model_name='BAAI/bge-m3',
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
    )
    print("starting embedding")
    store = Qdrant.from_documents(
        chunks,
        embedding_model,
        path="/data/local_qdrant",
        collection_name='all',
    )
    collections = {'all': store}
    print(collections)
    print("vector embeddings done")
    return collections
    
@st.cache_resource
def get_local_qdrant():
    """Open the existing on-disk Qdrant store at ``/data/local_qdrant`` and wrap its ``'all'`` collection.

    Meant for use once the store has been populated (see ``embed_chunks``).
    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the returned
    client and embedding model instead of rebuilding them.

    Returns:
        dict mapping ``'all'`` to a LangChain ``Qdrant`` vector store.
    """
    embedding_model = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3',
    )
    qdrant_client = QdrantClient(path="/data/local_qdrant")
    print("Collections in local Qdrant:", qdrant_client.get_collections())
    all_store = Qdrant(client=qdrant_client, collection_name='all', embeddings=embedding_model)
    return {'all': all_store}

def get_context(vectorstore, query, score_threshold=0.5, k=10):
    """Retrieve the chunks most similar to ``query`` from ``vectorstore``.

    Uses the ``similarity_score_threshold`` retriever, so at most ``k`` documents
    are returned and hits below ``score_threshold`` are filtered out.

    Args:
        vectorstore: a LangChain vector store (here the ``Qdrant`` store from
            ``get_local_qdrant``).
        query: the text to search for.
        score_threshold: minimum relevance score for a hit (default 0.5).
        k: maximum number of hits to return (default 10).

    Returns:
        list of the retrieved documents.
    """
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": score_threshold, "k": k},
    )
    # TODO: optional cross-encoder re-ranking of the hits (CrossEncoderReranker
    # and ContextualCompressionRetriever are already imported at the file top).
    context_retrieved = retriever.invoke(query)
    print(f"retrieved paragraphs:{len(context_retrieved)}")
    return context_retrieved
# The index is built offline (e.g. embed_chunks(get_chunks())); this app only
# reads the existing store.
vectorstores = get_local_qdrant()
vectorstore = vectorstores['all']
button = st.button("search")

# Streamlit re-executes the whole script on every interaction, so the search is
# only run when the button was actually clicked and a query was entered.
if button:
    if not var.strip():
        st.warning("Please enter a keyword to search.")
    else:
        results = get_context(vectorstore, f"find the relevant paragraphs for: {var}")
        st.write(f"Found {len(results)} results for query:{var}")

        for doc in results:
            # f-strings tolerate non-str metadata (e.g. a NaN title read from the CSV).
            st.subheader(f"{doc.metadata['iati_id']}:{doc.metadata['title_main']}")
            st.caption(f"Status:{doc.metadata['status']}, Country:{doc.metadata['country_name']}")
            st.write(doc.page_content)
            st.divider()