shrut123 commited on
Commit
9a6f924
·
verified ·
1 Parent(s): 125fa26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -5
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import streamlit as st
3
  from pinecone import Pinecone
4
  from sentence_transformers import SentenceTransformer
 
5
 
6
  # Title of the Streamlit App
7
  st.title("Medical Hybrid Search")
@@ -13,7 +14,6 @@ index = None
13
  def initialize_pinecone():
14
  api_key = os.getenv('PINECONE_API_KEY') # Get Pinecone API key from environment variable
15
  if api_key:
16
- # Initialize Pinecone client using the new class instance method
17
  pc = Pinecone(api_key=api_key)
18
  return pc
19
  else:
@@ -23,7 +23,6 @@ def initialize_pinecone():
23
  # Function to connect to the 'pubmed-splade' index
24
  def connect_to_index(pc):
25
  index_name = 'pubmed-splade' # Hardcoded index name
26
- # Connect to the 'pubmed-splade' index
27
  if index_name in pc.list_indexes().names():
28
  index = pc.Index(index_name)
29
  return index
@@ -35,6 +34,17 @@ def connect_to_index(pc):
35
  def encode_query(model, query_text):
36
  return model.encode(query_text).tolist()
37
 
 
 
 
 
 
 
 
 
 
 
 
38
  # Initialize Pinecone
39
  pc = initialize_pinecone()
40
 
@@ -49,14 +59,30 @@ if pc:
49
  # Query input
50
  query_text = st.text_input("Enter a Query to Search", "Can clinicians use the PHQ-9 to assess depression?")
51
 
 
 
 
52
  # Button to encode query and search the Pinecone index
53
  if st.button("Search Query"):
54
  if query_text and index:
 
55
  dense_vector = encode_query(model, query_text)
56
-
 
 
 
 
 
 
 
 
 
 
 
57
  # Search the index
58
  results = index.query(
59
- vector=dense_vector,
 
60
  top_k=3,
61
  include_metadata=True
62
  )
@@ -64,7 +90,7 @@ if pc:
64
  st.write("### Search Results:")
65
  for match in results.matches:
66
  st.markdown(f"#### Score: **{match.score:.4f}**")
67
- st.write(f" #### Result: ** {match.metadata.get('context', 'No context available.')} **")
68
  st.write("---")
69
  else:
70
  st.error("Please enter a query and ensure the index is initialized.")
 
2
  import streamlit as st
3
  from pinecone import Pinecone
4
  from sentence_transformers import SentenceTransformer
5
+ import torch
6
 
7
  # Title of the Streamlit App
8
  st.title("Medical Hybrid Search")
 
14
  def initialize_pinecone():
15
  api_key = os.getenv('PINECONE_API_KEY') # Get Pinecone API key from environment variable
16
  if api_key:
 
17
  pc = Pinecone(api_key=api_key)
18
  return pc
19
  else:
 
23
  # Function to connect to the 'pubmed-splade' index
24
  def connect_to_index(pc):
25
  index_name = 'pubmed-splade' # Hardcoded index name
 
26
  if index_name in pc.list_indexes().names():
27
  index = pc.Index(index_name)
28
  return index
 
34
  def encode_query(model, query_text):
35
  return model.encode(query_text).tolist()
36
 
37
+ # Function to create hybrid scaled vectors
38
+ def hybrid_scale(dense, sparse, alpha):
39
+ if alpha < 0 or alpha > 1:
40
+ raise ValueError("Alpha must be between 0 and 1")
41
+ hsparse = {
42
+ 'indices': sparse['indices'],
43
+ 'values': [v * (1 - alpha) for v in sparse['values']]
44
+ }
45
+ hdense = [v * alpha for v in dense]
46
+ return hdense, hsparse
47
+
48
  # Initialize Pinecone
49
  pc = initialize_pinecone()
50
 
 
59
  # Query input
60
  query_text = st.text_input("Enter a Query to Search", "Can clinicians use the PHQ-9 to assess depression?")
61
 
62
+ # Alpha input
63
+ alpha = st.slider("Set Alpha (for dense and sparse vector balancing)", 0.0, 1.0, 0.5)
64
+
65
  # Button to encode query and search the Pinecone index
66
  if st.button("Search Query"):
67
  if query_text and index:
68
+ # Encode query to get dense and sparse vectors
69
  dense_vector = encode_query(model, query_text)
70
+ input_ids = model.tokenizer(query_text, return_tensors='pt')
71
+ with torch.no_grad():
72
+ sparse_vector = sparse_model(d_kwargs=input_ids.to(device))['d_rep'].squeeze()
73
+
74
+ # Prepare sparse vector format for Pinecone
75
+ indices = sparse_vector.nonzero().squeeze().cpu().tolist()
76
+ values = sparse_vector[indices].cpu().tolist()
77
+ sparse_dict = {"indices": indices, "values": values}
78
+
79
+ # Scale dense and sparse vectors
80
+ hdense, hsparse = hybrid_scale(dense_vector, sparse_dict, alpha)
81
+
82
  # Search the index
83
  results = index.query(
84
+ vector=hdense,
85
+ sparse_vector=hsparse,
86
  top_k=3,
87
  include_metadata=True
88
  )
 
90
  st.write("### Search Results:")
91
  for match in results.matches:
92
  st.markdown(f"#### Score: **{match.score:.4f}**")
93
+ st.write(f" #### Context: {match.metadata.get('context', 'No context available.')}")
94
  st.write("---")
95
  else:
96
  st.error("Please enter a query and ensure the index is initialized.")