File size: 2,218 Bytes
c24fab8 43478d2 c7ced1a 1b29f8c c7ced1a c24fab8 c7ced1a c24fab8 c7ced1a c24fab8 c7ced1a 43478d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import pinecone
import requests
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
from config import config
def search(text: str, k: int = 5, timeout: float = 30.0):
    """Get the k closest articles to the text.

    Embeds ``text`` with ``_get_embeddings`` and POSTs the resulting vector
    to the Pinecone index's ``/query`` REST endpoint.

    Args:
        text: Query string to embed and search for.
        k: Number of matches to request (sent as ``top_k``).
        timeout: Seconds to wait for the HTTP request. Without a timeout,
            ``requests`` waits indefinitely on a stalled connection.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: If the response status code is not 200.
        requests.exceptions.RequestException: If the request fails at the
            transport level or exceeds ``timeout``.
    """
    embeds = _get_embeddings(text)

    # NOTE(review): "5b18b87" is hard-coded in the host name — presumably the
    # Pinecone project id; confirm and consider moving it into ``config``.
    url = (
        f"https://{config.pinecone_index}-5b18b87.svc."
        f"{config.pinecone_env}.pinecone.io/query"
    )
    r = requests.post(
        url,
        headers={
            "Api-Key": config.pinecone_api_key,
            "accept": "application/json",
            "content-type": "application/json",
        },
        json={
            "vector": embeds,
            "top_k": k,
            "includeMetadata": True,
            "includeValues": False,
        },
        timeout=timeout,
    )

    if r.status_code != 200:
        raise Exception(f"Error: {r.status_code} - {r.text}")
    return r.json()
def _get_embeddings(text: str):
    """Embed ``text`` with the tokenizer and model held in Streamlit session state.

    The text is tokenized into PyTorch tensors (with padding and truncation
    enabled), passed through the model with gradient tracking disabled, and
    the first element of the model output is averaged over dimension 1.

    Returns:
        The averaged tensor, squeezed and converted with ``tolist()``.
        Presumably a flat list of floats for a single input string —
        TODO confirm against the model's output shape.
    """
    state = st.session_state
    encoded = state.tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        model_output = state.model(**encoded)
        token_states = model_output[0]

    pooled = token_states.mean(dim=1)
    return pooled.squeeze().tolist()
# --- Streamlit page: password gate, resource loading, and search UI ---
# NOTE(review): the password is compared with plain ``==``; consider a
# constant-time comparison if this gate protects anything sensitive.
password = st.text_input("Password", type="password")

if password == config.password:
    st.title("PubMed Embeddings")
    st.subheader("Search for a PubMed article and get its id.")

    text = st.text_input("Search for a PubMed article", "Epidemiology of COVID-19")

    with st.spinner("Loading Embedding Model..."):
        # The pinecone client's keyword is ``environment``; the previous
        # ``env=`` keyword is not a parameter of ``pinecone.init``, so the
        # configured environment was never applied.
        pinecone.init(
            api_key=config.pinecone_api_key,
            environment=config.pinecone_env,
        )

        # Keep the index handle, tokenizer, and model in session state so
        # each is created only when its key is not already present.
        if "index" not in st.session_state:
            st.session_state.index = pinecone.Index(config.pinecone_index)
        if "tokenizer" not in st.session_state:
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if "model" not in st.session_state:
            st.session_state.model = AutoModel.from_pretrained(config.model_name)

    if st.button("Search"):
        with st.spinner("Searching..."):
            results = search(text)
            # One line per match: its id and its score to two decimals.
            for res in results["matches"]:
                st.write(f"{res['id']} - confidence: {res['score']:.2f}")
else:
    st.write("Password incorrect!")
|