ppsingh committed on
Commit
3d3fc58
·
verified ·
1 Parent(s): 5f9ca3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -4
app.py CHANGED
@@ -6,13 +6,17 @@ from torch import cuda
6
  from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
7
  from langchain_community.vectorstores import Qdrant
8
  from qdrant_client import QdrantClient
 
 
 
 
9
  device = 'cuda' if cuda.is_available() else 'cpu'
10
 
11
 
12
  st.set_page_config(page_title="SEARCH IATI",layout='wide')
13
  st.title("SEARCH IATI Database")
14
  var=st.text_input("enter keyword")
15
- title = var.replace(' ','+')
16
 
17
  def create_chunks(text):
18
  text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=0)
@@ -75,6 +79,41 @@ def embed_chunks(chunks):
75
  print("vector embeddings done")
76
  return qdrant_collections
77
 
78
- chunks = get_chunks()
79
- qdrant_col = embed_chunks(chunks)
80
- st.write("Success")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
7
  from langchain_community.vectorstores import Qdrant
8
  from qdrant_client import QdrantClient
9
+ from langchain.retrievers import ContextualCompressionRetriever
10
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
11
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
12
+
13
# Pick the compute device once at startup: GPU when available, else CPU.
device = 'cuda' if cuda.is_available() else 'cpu'

# Basic Streamlit page chrome and the single keyword input for the search.
st.set_page_config(page_title="SEARCH IATI", layout='wide')
st.title("SEARCH IATI Database")
var = st.text_input("enter keyword")
19
+
20
 
21
  def create_chunks(text):
22
  text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=0)
 
79
  print("vector embeddings done")
80
  return qdrant_collections
81
 
82
def get_local_qdrant():
    """Connect to the existing on-disk Qdrant store and wrap its collections.

    Once the local Qdrant server has been created, this opens a client on the
    persisted path and returns a dict mapping collection name -> LangChain
    ``Qdrant`` vectorstore (currently only the 'all' collection).
    """
    # Embedding model must match the one used when the collection was built.
    embedder = HuggingFaceEmbeddings(
        model_name='BAAI/bge-m3',
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
    )
    client = QdrantClient(path="/data/local_qdrant")
    print("Collections in local Qdrant:", client.get_collections())
    collections = {}
    collections['all'] = Qdrant(
        client=client,
        collection_name='all',
        embeddings=embedder,
    )
    return collections
94
+
95
def get_context(vectorstore, query):
    """Retrieve the paragraphs most similar to *query* from *vectorstore*.

    Args:
        vectorstore: a LangChain vectorstore exposing ``as_retriever``.
        query: free-text search query.

    Returns:
        The list of retrieved documents (up to 10, score threshold 0.5).
    """
    # BUGFIX: the original search_kwargs literal had a stray ')' after
    # "k": 10, which was a syntax error; the dict is now well-formed.
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.5, "k": 10},
    )
    # NOTE(review): a CrossEncoderReranker (ContextualCompressionRetriever)
    # stage was sketched here but referenced an undefined `model_config`;
    # it stays disabled until that config exists.
    context_retrieved = retriever.invoke(query)
    print(f"retrieved paragraphs:{len(context_retrieved)}")

    return context_retrieved
114
# chunks = get_chunks()  # one-time ingestion; not needed for querying
vectorstores = get_local_qdrant()
button = st.button("search")
if button:
    # BUGFIX: get_local_qdrant() returns a dict of collections, but
    # get_context expects a single vectorstore — pass the 'all' collection,
    # and only run the (expensive) retrieval when the button is pressed.
    results = get_context(vectorstores['all'],
                          f"find the relevant paragraphs for: {var}")
    st.write(results)