File size: 5,129 Bytes
0130713
c2f0c5c
 
 
f5dac9b
 
bbc879f
 
3d3fc58
 
 
 
f5dac9b
0130713
 
 
 
b60ea35
3d3fc58
0130713
c2f0c5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5dac9b
 
 
 
 
 
 
 
865766a
f5dac9b
 
 
 
 
 
 
 
 
 
 
 
3d3fc58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
075510e
3d3fc58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import streamlit as st
import pandas as pd
from langchain_text_splitters import TokenTextSplitter
from langchain.docstore.document import Document
from torch import cuda
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# Embedding models run on the GPU when torch can see one, otherwise on the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


# Page layout is configured before any other Streamlit command is issued.
st.set_page_config(layout='wide', page_title="SEARCH IATI")
st.title("SEARCH IATI Database")
# Free-text search keyword entered by the user.
var = st.text_input("enter keyword")


def create_chunks(text, chunk_size=500, chunk_overlap=0):
    """Split ``text`` into consecutive token-based chunks.

    Token counts come from ``TokenTextSplitter``'s own tokenizer, which is not
    necessarily the tokenizer of the embedding model used downstream.

    Args:
        text: the string to split.
        chunk_size: maximum number of tokens per chunk (default 500).
        chunk_overlap: number of tokens shared by consecutive chunks (default 0).

    Returns:
        list of chunk strings.
    """
    text_splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_text(text)

def get_chunks(data_dir="iati_files"):
    """Load the IATI project CSVs and turn every project text chunk into a ``Document``.

    The five ``project_*.csv`` files in ``data_dir`` are inner-joined on ``iati_id``.
    Only rows whose ``client`` column contains ``'bmz'`` are kept. Each project's
    title and description are joined and split with ``create_chunks``, and every
    resulting chunk becomes one ``Document`` carrying the project's metadata.

    Args:
        data_dir: directory holding the ``project_*.csv`` files (default ``"iati_files"``).

    Returns:
        list of ``Document`` objects, one per text chunk.
    """
    file_names = [
        "project_orgas.csv",
        "project_region.csv",
        "project_sector.csv",
        "project_status.csv",
        "project_texts.csv",
    ]
    frames = [pd.read_csv(f"{data_dir}/{name}") for name in file_names]

    # Inner joins: a project survives only if it is present in all five files.
    # NOTE(review): if a file has several rows per iati_id (e.g. several regions or
    # sectors), the joins multiply rows and the same chunks are emitted once per
    # combination -- confirm this duplication is intended.
    projects_df = frames[0]
    for frame in frames[1:]:
        projects_df = pd.merge(projects_df, frame, on='iati_id', how='inner')

    # na=False treats a missing client as "no match"; otherwise the mask would hold
    # NaN and boolean indexing would raise.
    is_bmz = projects_df['client'].str.contains('bmz', na=False)
    giz_df = projects_df[is_bmz].reset_index(drop=True)
    giz_df = giz_df.drop(columns=['orga_abbreviation', 'client',
                                  'orga_full_name', 'country',
                                  'country_flag', 'crs_5_code', 'crs_3_code',
                                  'sgd_pred_code'])

    # Missing titles/descriptions (NaN) become '' so concatenation cannot fail, and the
    # two parts are joined with a space so their boundary words are not fused together.
    full_text = (
        giz_df['title_main'].fillna('').astype(str)
        + ' '
        + giz_df['description_main'].fillna('').astype(str)
    ).str.strip()
    giz_df['chunks'] = full_text.apply(create_chunks)

    # One row per chunk. A project without text yields an empty chunk list, which
    # explode turns into NaN; such rows carry nothing to embed and are dropped.
    giz_df = giz_df.explode('chunks', ignore_index=True)
    giz_df = giz_df.dropna(subset=['chunks'])

    return [
        Document(
            page_content=row['chunks'],
            metadata={
                "iati_id": row['iati_id'],
                "iati_orga_id": row['iati_orga_id'],
                "country_name": str(row['country_name']),
                "crs_5_name": row['crs_5_name'],
                "crs_3_name": row['crs_3_name'],
                "sgd_pred_str": row['sgd_pred_str'],
                "status": row['status'],
                "title_main": row['title_main'],
            },
        )
        for row in giz_df.to_dict('records')
    ]

def embed_chunks(chunks, path="/data/local_qdrant", collection_name='all'):
    """Embed ``chunks`` with BAAI/bge-m3 and store them in a local on-disk Qdrant collection.

    The embedding model and normalization must match those used in ``get_local_qdrant``,
    which reopens this store for querying.

    NOTE(review): the returned store keeps an embedded Qdrant client open on ``path``;
    opening the same folder again in this process (e.g. via ``get_local_qdrant``) is
    expected to be refused by Qdrant's local mode -- build the index in a separate run.

    Args:
        chunks: LangChain ``Document`` objects to index (e.g. the output of ``get_chunks``).
        path: directory of the local Qdrant storage (default ``"/data/local_qdrant"``).
        collection_name: Qdrant collection to write into (default ``'all'``).

    Returns:
        dict mapping ``collection_name`` to the populated ``Qdrant`` vector store.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name='BAAI/bge-m3',
        model_kwargs={'device': device},
        # Unit-length vectors, so dot product and cosine similarity coincide.
        encode_kwargs={'normalize_embeddings': True},
    )
    print("starting embedding")
    vectorstore = Qdrant.from_documents(
        chunks,
        embeddings,
        path=path,
        collection_name=collection_name,
    )
    print("vector embeddings done")
    return {collection_name: vectorstore}
    
@st.cache_resource
def get_local_qdrant(path="/data/local_qdrant", collection_name='all'):
    """Connect to an existing local Qdrant store and wrap it as a LangChain vector store.

    Use this once the store has been created (e.g. by ``embed_chunks``).

    The result is cached with ``st.cache_resource`` because Streamlit re-executes the
    whole script on every widget interaction. Without the cache, every rerun would
    reload the bge-m3 model and open another ``QdrantClient`` on ``path``, while a
    previous client may still hold the folder. Qdrant's local (embedded) mode does not
    allow concurrent access to one storage folder.

    Args:
        path: directory of the local Qdrant storage (default ``"/data/local_qdrant"``).
        collection_name: collection to wrap (default ``'all'``).

    Returns:
        dict mapping ``collection_name`` to a ``Qdrant`` vector store.
    """
    # Must match the model and normalization used when the collection was embedded.
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3',
    )
    client = QdrantClient(path=path)
    print("Collections in local Qdrant:", client.get_collections())
    return {
        collection_name: Qdrant(client=client, collection_name=collection_name, embeddings=embeddings),
    }

def get_context(vectorstore, query, score_threshold=0.5, k=10):
    """Retrieve the stored chunks most similar to ``query``.

    Args:
        vectorstore: a single LangChain vector store (e.g. one value of the dict
            returned by ``get_local_qdrant``), not the dict itself.
        query: text to search for.
        score_threshold: minimum relevance score a chunk needs in order to be
            returned (default 0.5).
        k: maximum number of chunks to return (default 10).

    Returns:
        list of retrieved ``Document`` objects; empty when nothing passes the threshold.
    """
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": score_threshold, "k": k},
    )
    context_retrieved = retriever.invoke(query)
    print(f"retrieved paragraphs:{len(context_retrieved)}")
    return context_retrieved
# To (re)build the index, run embed_chunks(get_chunks()) once in a separate run;
# this app only reads the existing local Qdrant store.
vectorstores = get_local_qdrant()
button = st.button("search")

# get_local_qdrant returns a dict of collections; the retriever needs the store itself.
# Search only when the button was clicked and a non-blank keyword was entered.
if button and var.strip():
    results = get_context(vectorstores['all'], f"find the relevant paragraphs for: {var}")
    st.write(results)
elif button:
    st.warning("Please enter a keyword to search.")