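"""Streamlit demo: semantic search over NLP2023 paper titles.

A query is embedded with a Japanese BERT-based encoder and matched against
precomputed title embeddings via a FAISS index.
"""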
import os

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

# Avoid crashes when faiss and torch each load their own OpenMP runtime
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
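
# Japanese BERT tokenizer plus the fine-tuned sentence encoder; cached so that
# Streamlit reruns of the script do not reload the weights.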
@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")
    model = AutoModel.from_pretrained("kaisugi/anlp_embedding_model")
    model.eval()
    return model, tokenizer
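
# Paper metadata for NLP2023; the code below expects "pid" and "title" columns.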
@st.cache(allow_output_mutation=True)
def load_title_data():
    title_df = pd.read_csv("anlp2023.csv")
    return title_df
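
# Precomputed title embeddings, expected to align row-for-row with anlp2023.csv.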
@st.cache(allow_output_mutation=True)
def load_title_embeddings():
    npz_comp = np.load("anlp_title_embeddings.npz")
    title_embeddings = npz_comp["arr_0"]
    return title_embeddings
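
# The function below is NOT part of the app: it is a hedged sketch of how a file
# like anlp_title_embeddings.npz could be produced offline, assuming the titles
# are encoded with the same CLS-token pooling used in get_retrieval_results().
# The function name and arguments are illustrative, not from the original repo.
def build_title_embeddings(csv_path="anlp2023.csv", out_path="anlp_title_embeddings.npz"):
    model, tokenizer = load_model_and_tokenizer()
    titles = pd.read_csv(csv_path)["title"].tolist()
    vectors = []
    with torch.no_grad():
        for title in titles:
            inputs = tokenizer(title, truncation=True, max_length=512, return_tensors="pt")
            outputs = model(**inputs)
            # CLS-token embedding, matching the query encoder
            vectors.append(outputs.last_hidden_state[:, 0, :][0].cpu().numpy())
    # Saved as arr_0, which load_title_embeddings() reads back
    np.savez(out_path, np.array(vectors, dtype=np.float32))
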
@st.cache
def get_retrieval_results(index, input_text, top_k, tokenizer, title_df):
    # Embed the query with the [CLS] token vector of `model` (the module-level
    # encoder loaded in the __main__ block), then retrieve the top_k nearest
    # title embeddings from the FAISS index.
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation="only_second",
            return_tensors="pt",
            max_length=512,
        )
        outputs = model(**inputs)
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()

    _, ids = index.search(x=np.array([query_embeddings]), k=top_k)

    retrieved_titles = []
    retrieved_pids = []
    for i in ids[0]:
        retrieved_titles.append(title_df.loc[i, "title"])
        retrieved_pids.append(title_df.loc[i, "pid"])

    df = pd.DataFrame({"pids": retrieved_pids, "paper": retrieved_titles})
    return df
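
# Streamlit entry point: load the model and data, build the FAISS index,
# and render the search UI.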
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    title_df = load_title_data()
    title_embeddings = load_title_embeddings()

    # Exact L2 search over 768-dimensional BERT embeddings
    index = faiss.IndexFlatL2(768)
    index.add(title_embeddings)

    st.markdown("## NLP2023 論文検索")  # "NLP2023 paper search"

    input_text = st.text_input('input', '', placeholder='ここに論文のタイトルを入力してください')  # "Enter a paper title here"
    top_k = st.number_input('top_k', min_value=1, value=10, step=1)

    if st.button('検索'):  # "Search"
        stripped_input_text = input_text.strip()
        df = get_retrieval_results(index, stripped_input_text, top_k, tokenizer, title_df)
        st.table(df)