Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import streamlit as st
|
3 |
from pinecone import Pinecone
|
4 |
from sentence_transformers import SentenceTransformer
|
|
|
5 |
|
6 |
# Title of the Streamlit App
|
7 |
st.title("Medical Hybrid Search")
|
@@ -13,7 +14,6 @@ index = None
|
|
13 |
def initialize_pinecone():
|
14 |
api_key = os.getenv('PINECONE_API_KEY') # Get Pinecone API key from environment variable
|
15 |
if api_key:
|
16 |
-
# Initialize Pinecone client using the new class instance method
|
17 |
pc = Pinecone(api_key=api_key)
|
18 |
return pc
|
19 |
else:
|
@@ -23,7 +23,6 @@ def initialize_pinecone():
|
|
23 |
# Function to connect to the 'pubmed-splade' index
|
24 |
def connect_to_index(pc):
|
25 |
index_name = 'pubmed-splade' # Hardcoded index name
|
26 |
-
# Connect to the 'pubmed-splade' index
|
27 |
if index_name in pc.list_indexes().names():
|
28 |
index = pc.Index(index_name)
|
29 |
return index
|
@@ -35,6 +34,17 @@ def connect_to_index(pc):
|
|
35 |
def encode_query(model, query_text):
|
36 |
return model.encode(query_text).tolist()
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
# Initialize Pinecone
|
39 |
pc = initialize_pinecone()
|
40 |
|
@@ -49,14 +59,30 @@ if pc:
|
|
49 |
# Query input
|
50 |
query_text = st.text_input("Enter a Query to Search", "Can clinicians use the PHQ-9 to assess depression?")
|
51 |
|
|
|
|
|
|
|
52 |
# Button to encode query and search the Pinecone index
|
53 |
if st.button("Search Query"):
|
54 |
if query_text and index:
|
|
|
55 |
dense_vector = encode_query(model, query_text)
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
# Search the index
|
58 |
results = index.query(
|
59 |
-
vector=
|
|
|
60 |
top_k=3,
|
61 |
include_metadata=True
|
62 |
)
|
@@ -64,7 +90,7 @@ if pc:
|
|
64 |
st.write("### Search Results:")
|
65 |
for match in results.matches:
|
66 |
st.markdown(f"#### Score: **{match.score:.4f}**")
|
67 |
-
st.write(f" ####
|
68 |
st.write("---")
|
69 |
else:
|
70 |
st.error("Please enter a query and ensure the index is initialized.")
|
|
|
2 |
import streamlit as st
|
3 |
from pinecone import Pinecone
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
+
import torch
|
6 |
|
7 |
# Title of the Streamlit App
|
8 |
st.title("Medical Hybrid Search")
|
|
|
14 |
def initialize_pinecone():
|
15 |
api_key = os.getenv('PINECONE_API_KEY') # Get Pinecone API key from environment variable
|
16 |
if api_key:
|
|
|
17 |
pc = Pinecone(api_key=api_key)
|
18 |
return pc
|
19 |
else:
|
|
|
23 |
# Function to connect to the 'pubmed-splade' index
|
24 |
def connect_to_index(pc):
|
25 |
index_name = 'pubmed-splade' # Hardcoded index name
|
|
|
26 |
if index_name in pc.list_indexes().names():
|
27 |
index = pc.Index(index_name)
|
28 |
return index
|
|
|
34 |
def encode_query(model, query_text):
|
35 |
return model.encode(query_text).tolist()
|
36 |
|
37 |
+
# Function to create hybrid scaled vectors
|
38 |
+
def hybrid_scale(dense, sparse, alpha):
|
39 |
+
if alpha < 0 or alpha > 1:
|
40 |
+
raise ValueError("Alpha must be between 0 and 1")
|
41 |
+
hsparse = {
|
42 |
+
'indices': sparse['indices'],
|
43 |
+
'values': [v * (1 - alpha) for v in sparse['values']]
|
44 |
+
}
|
45 |
+
hdense = [v * alpha for v in dense]
|
46 |
+
return hdense, hsparse
|
47 |
+
|
48 |
# Initialize Pinecone
|
49 |
pc = initialize_pinecone()
|
50 |
|
|
|
59 |
# Query input
|
60 |
query_text = st.text_input("Enter a Query to Search", "Can clinicians use the PHQ-9 to assess depression?")
|
61 |
|
62 |
+
# Alpha input
|
63 |
+
alpha = st.slider("Set Alpha (for dense and sparse vector balancing)", 0.0, 1.0, 0.5)
|
64 |
+
|
65 |
# Button to encode query and search the Pinecone index
|
66 |
if st.button("Search Query"):
|
67 |
if query_text and index:
|
68 |
+
# Encode query to get dense and sparse vectors
|
69 |
dense_vector = encode_query(model, query_text)
|
70 |
+
input_ids = model.tokenizer(query_text, return_tensors='pt')
|
71 |
+
with torch.no_grad():
|
72 |
+
sparse_vector = sparse_model(d_kwargs=input_ids.to(device))['d_rep'].squeeze()
|
73 |
+
|
74 |
+
# Prepare sparse vector format for Pinecone
|
75 |
+
indices = sparse_vector.nonzero().squeeze().cpu().tolist()
|
76 |
+
values = sparse_vector[indices].cpu().tolist()
|
77 |
+
sparse_dict = {"indices": indices, "values": values}
|
78 |
+
|
79 |
+
# Scale dense and sparse vectors
|
80 |
+
hdense, hsparse = hybrid_scale(dense_vector, sparse_dict, alpha)
|
81 |
+
|
82 |
# Search the index
|
83 |
results = index.query(
|
84 |
+
vector=hdense,
|
85 |
+
sparse_vector=hsparse,
|
86 |
top_k=3,
|
87 |
include_metadata=True
|
88 |
)
|
|
|
90 |
st.write("### Search Results:")
|
91 |
for match in results.matches:
|
92 |
st.markdown(f"#### Score: **{match.score:.4f}**")
|
93 |
+
st.write(f" #### Context: {match.metadata.get('context', 'No context available.')}")
|
94 |
st.write("---")
|
95 |
else:
|
96 |
st.error("Please enter a query and ensure the index is initialized.")
|