ppsingh commited on
Commit
c567921
·
1 Parent(s): cb359de

add search

Browse files
Files changed (2) hide show
  1. app.py +12 -7
  2. appStore/search.py +47 -2
app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
3
  from appStore.prep_data import process_giz_worldwide
4
  from appStore.prep_utils import create_documents, get_client
5
  from appStore.embed import hybrid_embed_chunks
 
6
  from torch import cuda
7
  # get the device to be used eithe gpu or cpu
8
  device = 'cuda' if cuda.is_available() else 'cpu'
@@ -19,12 +20,14 @@ var=st.text_input("enter keyword")
19
  ##### Convert to langchain documents
20
  #temp_doc = create_documents(chunks,'chunks')
21
  ##### Embed and store docs, check if collection exist then you need to update the collection
22
- #collection_name = "giz_worldwide"
23
  #hybrid_embed_chunks(docs= temp_doc, collection_name = collection_name)
24
 
25
  ################### Hybrid Search ######################################################
26
  client = get_client()
27
  print(client.get_collections())
 
 
28
 
29
 
30
  button=st.button("search")
@@ -32,10 +35,12 @@ button=st.button("search")
32
  #print(found_docs)
33
  # results= get_context(vectorstore, f"find the relvant paragraphs for: {var}")
34
  if button:
35
- st.write(f"Found {len(results)} results for query:{var}")
 
 
36
 
37
- for i in results:
38
- st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
39
- st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")
40
- st.write(i.page_content)
41
- st.divider()
 
3
  from appStore.prep_data import process_giz_worldwide
4
  from appStore.prep_utils import create_documents, get_client
5
  from appStore.embed import hybrid_embed_chunks
6
+ from appStore.search import hybrid_search
7
  from torch import cuda
8
  # get the device to be used: either gpu or cpu
9
  device = 'cuda' if cuda.is_available() else 'cpu'
 
20
  ##### Convert to langchain documents
21
  #temp_doc = create_documents(chunks,'chunks')
22
##### Embed and store docs; check if the collection exists, then you need to update the collection
collection_name = "giz_worldwide"
#hybrid_embed_chunks(docs= temp_doc, collection_name = collection_name)

################### Hybrid Search ######################################################
client = get_client()
print(client.get_collections())

button = st.button("search")
#print(found_docs)
# results= get_context(vectorstore, f"find the relevant paragraphs for: {var}")
if button:
    # Run the search only after the user clicks the button, so we don't hit
    # the vector store (and embed an empty query string) on every Streamlit rerun.
    results = hybrid_search(client, var, collection_name)
    st.write(f"Showing Top 10 results for query:{var}")
    # results[0] holds the dense (semantic) hits, results[1] the sparse (BM25) hits
    # — the second label was a copy-paste of the first ("Semantic") and is fixed here.
    st.write(f"Semantic: {len(results[0])}")
    st.write(f"Lexical: {len(results[1])}")

    # for i in results:
    #     st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
    #     st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")
    #     st.write(i.page_content)
    #     st.divider()
appStore/search.py CHANGED
@@ -1,5 +1,50 @@
1
  from appStore.prep_utils import get_client
 
 
 
 
2
 
3
- def hybrid_search(client, query):
4
- print("wip")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
1
from appStore.prep_utils import get_client
from langchain_qdrant import FastEmbedSparse, RetrievalMode
from langchain_huggingface import HuggingFaceEmbeddings
# `models` was referenced below but never imported, which made every call
# raise NameError — import it from qdrant_client.
from qdrant_client import models
from torch import cuda

import streamlit as st

# get the device to be used: either gpu or cpu
device = 'cuda' if cuda.is_available() else 'cpu'


def hybrid_search(client, query, collection_name):
    """Run a hybrid (dense + sparse) search against a Qdrant collection.

    Embeds *query* twice — densely with the BAAI/bge-m3 HuggingFace model and
    sparsely with Qdrant's BM25 FastEmbed model — and issues one batched
    request containing a semantic search on the "text-dense" named vector and
    a lexical search on the "text-sparse" named sparse vector, each limited to
    the top 10 hits.

    Parameters
    ----------
    client : qdrant_client.QdrantClient
        Connected Qdrant client (see ``get_client``).
    query : str
        Free-text search string.
    collection_name : str
        Collection holding the "text-dense" and "text-sparse" named vectors.

    Returns
    -------
    list
        Two result lists as returned by ``client.search_batch``:
        ``results[0]`` dense/semantic hits, ``results[1]`` sparse/BM25 hits.
    """
    # NOTE(review): both embedding models are re-instantiated (and potentially
    # re-downloaded) on every call; consider caching them, e.g. with
    # st.cache_resource, if query latency matters.
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3',
    )
    sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")

    # embed query
    q_dense = embeddings.embed_query(query)
    q_sparse = sparse_embeddings.embed_query(query)

    results = client.search_batch(
        collection_name=collection_name,
        requests=[
            # dense (semantic) request against the "text-dense" named vector
            models.SearchRequest(
                vector=models.NamedVector(
                    name="text-dense",
                    vector=q_dense,
                ),
                limit=10,
            ),
            # sparse (BM25) request against the "text-sparse" named sparse vector
            models.SearchRequest(
                vector=models.NamedSparseVector(
                    name="text-sparse",
                    vector=models.SparseVector(
                        indices=q_sparse.indices,
                        values=q_sparse.values,
                    ),
                ),
                limit=10,
            ),
        ],
    )

    print(results)
    return results
50