File size: 2,218 Bytes
c24fab8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43478d2
c7ced1a
 
1b29f8c
 
 
c7ced1a
c24fab8
c7ced1a
 
 
 
 
 
 
 
c24fab8
c7ced1a
 
 
c24fab8
c7ced1a
 
 
 
43478d2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import hmac

import pinecone
import requests
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel

from config import config
    
    
def search(text: str, k: int = 5, timeout: float = 30.0):
    """Query Pinecone for the ``k`` vectors closest to the embedding of ``text``.

    Args:
        text: Free-text query; it is embedded with ``_get_embeddings``.
        k: Number of nearest neighbours to request (sent as ``top_k``).
        timeout: Seconds to wait for the HTTP response. Without a timeout
            ``requests`` can block the Streamlit script indefinitely.

    Returns:
        The decoded JSON body of the Pinecone ``/query`` response, requested
        with metadata included and vector values excluded.

    Raises:
        Exception: If Pinecone answers with a status code other than 200.
        requests.exceptions.RequestException: On connection failure or timeout.
    """
    embeds = _get_embeddings(text)

    # The index host embeds a project id. Allow overriding it from config,
    # keeping the previously hard-coded value as the default.
    project_id = getattr(config, "pinecone_project_id", "5b18b87")
    url = (
        f"https://{config.pinecone_index}-{project_id}"
        f".svc.{config.pinecone_env}.pinecone.io/query"
    )

    response = requests.post(
        url,
        headers={
            "Api-Key": config.pinecone_api_key,
            "accept": "application/json",
            "content-type": "application/json",
        },
        json={
            "vector": embeds,
            "top_k": k,
            "includeMetadata": True,
            "includeValues": False,
        },
        timeout=timeout,
    )

    if response.status_code != 200:
        raise Exception(f"Error: {response.status_code} - {response.text}")
    return response.json()

    
def _get_embeddings(text: "str | list[str]", *, tokenizer=None, model=None):
    """Embed ``text`` by mean-pooling the model's first output over real tokens.

    Padding positions are excluded from the average using the tokenizer's
    ``attention_mask``. For a single string (no padding) this equals a plain
    mean over all tokens. For a batch of strings of different lengths, it
    stops pad tokens from diluting the shorter texts' embeddings.

    Args:
        text: A query string, or a list of strings to embed as one batch.
        tokenizer: Hugging Face tokenizer. Defaults to
            ``st.session_state.tokenizer``.
        model: Hugging Face model. Defaults to ``st.session_state.model``.

    Returns:
        A list of floats when the batch size is 1. Otherwise a list of such
        lists, one per input text, because ``squeeze`` only drops the batch
        dimension when it has size 1.
    """
    if tokenizer is None:
        tokenizer = st.session_state.tokenizer
    if model is None:
        model = st.session_state.model

    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    with torch.no_grad():
        # Assumes output[0] is (batch, seq_len, hidden), as for AutoModel
        # encoders. TODO: confirm for config.model_name.
        last_hidden_states = model(**encoded)[0]

    mask = encoded.get("attention_mask")
    if mask is None:
        # Tokenizer provides no mask: fall back to an unweighted token mean.
        pooled = last_hidden_states.mean(dim=1)
    else:
        weights = mask.unsqueeze(-1).to(last_hidden_states.dtype)
        summed = (last_hidden_states * weights).sum(dim=1)
        # Clamp guards against division by zero for an all-padding row.
        counts = weights.sum(dim=1).clamp(min=1e-9)
        pooled = summed / counts

    return pooled.squeeze().tolist()



password = st.text_input("Password", type="password")

# Constant-time comparison so response timing does not reveal how much of the
# password matched. Both sides are encoded because compare_digest rejects
# non-ASCII str. Assumes config.password is a str.
password_ok = bool(password) and hmac.compare_digest(
    password.encode("utf-8"), config.password.encode("utf-8")
)

if password_ok:
    st.title("PubMed Embeddings")
    st.subheader("Search for a PubMed article and get its id.")

    text = st.text_input("Search for a PubMed article", "Epidemiology of COVID-19")

    with st.spinner("Loading Embedding Model..."):
        # Streamlit reruns this script on every interaction; only initialise
        # the client and load the model once per session.
        if "index" not in st.session_state:
            # The Pinecone client's keyword is `environment`, not `env`.
            pinecone.init(api_key=config.pinecone_api_key, environment=config.pinecone_env)
            st.session_state.index = pinecone.Index(config.pinecone_index)
        if "tokenizer" not in st.session_state:
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if "model" not in st.session_state:
            st.session_state.model = AutoModel.from_pretrained(config.model_name)

    if st.button("Search"):
        if not text.strip():
            st.warning("Please enter a search query.")
        else:
            try:
                with st.spinner("Searching..."):
                    results = search(text)
            except Exception as err:  # UI boundary: show API/network failures to the user
                st.error(f"Search failed: {err}")
            else:
                matches = results.get("matches", [])
                if not matches:
                    st.write("No matching articles found.")
                for res in matches:
                    st.write(f"{res['id']} - confidence: {res['score']:.2f}")
elif password:
    # Only complain once something has actually been typed.
    st.write("Password incorrect!")