shrut123 committed on
Commit
dd92eea
·
verified ·
1 Parent(s): 9a6f924

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -3,6 +3,8 @@ import streamlit as st
3
  from pinecone import Pinecone
4
  from sentence_transformers import SentenceTransformer
5
  import torch
 
 
6
 
7
  # Title of the Streamlit App
8
  st.title("Medical Hybrid Search")
@@ -55,6 +57,12 @@ if pc:
55
 
56
  # Model for query encoding
57
  model = SentenceTransformer('msmarco-bert-base-dot-v5')
 
 
 
 
 
 
58
 
59
  # Query input
60
  query_text = st.text_input("Enter a Query to Search", "Can clinicians use the PHQ-9 to assess depression?")
@@ -67,7 +75,7 @@ if pc:
67
  if query_text and index:
68
  # Encode query to get dense and sparse vectors
69
  dense_vector = encode_query(model, query_text)
70
- input_ids = model.tokenizer(query_text, return_tensors='pt')
71
  with torch.no_grad():
72
  sparse_vector = sparse_model(d_kwargs=input_ids.to(device))['d_rep'].squeeze()
73
 
@@ -90,7 +98,7 @@ if pc:
90
  st.write("### Search Results:")
91
  for match in results.matches:
92
  st.markdown(f"#### Score: **{match.score:.4f}**")
93
- st.write(f" #### Context: {match.metadata.get('context', 'No context available.')}")
94
  st.write("---")
95
  else:
96
  st.error("Please enter a query and ensure the index is initialized.")
 
3
  from pinecone import Pinecone
4
  from sentence_transformers import SentenceTransformer
5
  import torch
6
+ from splade.models.transformer_rep import Splade
7
+ from transformers import AutoTokenizer
8
 
9
  # Title of the Streamlit App
10
  st.title("Medical Hybrid Search")
 
57
 
58
  # Model for query encoding
59
  model = SentenceTransformer('msmarco-bert-base-dot-v5')
60
+
61
+ # Initialize sparse model and tokenizer
62
+ sparse_model_id = 'naver/splade-cocondenser-ensembledistil'
63
+ sparse_model = Splade(sparse_model_id, agg='max')
64
+ sparse_model.eval() # Set the model to evaluation mode
65
+ tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)
66
 
67
  # Query input
68
  query_text = st.text_input("Enter a Query to Search", "Can clinicians use the PHQ-9 to assess depression?")
 
75
  if query_text and index:
76
  # Encode query to get dense and sparse vectors
77
  dense_vector = encode_query(model, query_text)
78
+ input_ids = tokenizer(query_text, return_tensors='pt')
79
  with torch.no_grad():
80
  sparse_vector = sparse_model(d_kwargs=input_ids.to(device))['d_rep'].squeeze()
81
 
 
98
  st.write("### Search Results:")
99
  for match in results.matches:
100
  st.markdown(f"#### Score: **{match.score:.4f}**")
101
+ st.write(f"####Context:{match.metadata.get('context', 'No context available.')}")
102
  st.write("---")
103
  else:
104
  st.error("Please enter a query and ensure the index is initialized.")