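# Header residue from the hosting web page, kept as comments so this file parses:
# author: kaisugi · commit message: "update" · revision e1a3f25 · size: 3.05 kB
# (page buttons: raw / history / blame)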
from transformers import AutoModel, AutoTokenizer
import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
import os
# NOTE(review): presumably set so that torch and faiss, which may each ship an
# Intel OpenMP runtime, can be loaded in one process instead of aborting on a
# duplicate libiomp initialisation — confirm. Intel's own error message calls
# this override an unsafe, unsupported workaround.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    """Fetch the ``kaisugi/scitoricsbert`` encoder together with its tokenizer.

    The model is put into evaluation (inference) mode before it is returned.
    ``st.cache`` memoises the result, so Streamlit reruns reuse the same
    objects instead of reloading them.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    checkpoint = "kaisugi/scitoricsbert"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)
    encoder.eval()
    return encoder, tok
@st.cache(allow_output_mutation=True)
def load_sentence_data():
    """Read the sentence corpus from ``sentence_data_789k.csv.gz``.

    pandas infers gzip compression from the file extension.

    Returns:
        pandas.DataFrame: the parsed CSV (memoised across reruns by ``st.cache``).
    """
    return pd.read_csv("sentence_data_789k.csv.gz")
@st.cache(allow_output_mutation=True)
def load_sentence_embeddings_and_index():
    """Load the precomputed sentence embeddings and build an IVF-PQ FAISS index.

    Reads array ``arr_0`` from ``sentence_embeddings_789k.npz`` and L2-normalises
    every row in place. It then trains a ``faiss.IndexIVFPQ`` (whose coarse
    quantizer is an HNSW graph) on the first 39,000 rows and adds all rows to it.

    Returns:
        tuple: ``(sentence_embeddings, index)``. The first item is the
        normalised float32 embedding matrix. The second is the trained,
        populated index, with ``nprobe`` set to 8.

    NOTE(review): ``st.cache`` is deprecated in recent Streamlit releases in
    favour of ``st.cache_resource``; it is kept here so older Streamlit
    versions keep working.
    """
    # ``with`` closes the .npz file handle. Indexing an NpzFile reads the member
    # fully into memory, so the array stays valid after the file is closed.
    with np.load("sentence_embeddings_789k.npz") as npz_comp:
        # faiss only accepts C-contiguous float32 input. This call does not
        # copy when the stored array already has that layout.
        sentence_embeddings = np.ascontiguousarray(npz_comp["arr_0"], dtype=np.float32)
    # With unit-length rows, ranking by L2 distance is the same as ranking by
    # cosine similarity, provided the query is normalised the same way.
    faiss.normalize_L2(sentence_embeddings)

    # Embedding width is taken from the data rather than hard-coded.
    D = sentence_embeddings.shape[1]

    # Training sample: the first 39,000 rows (slicing clamps to the row count).
    # NOTE(review): a contiguous prefix may be unrepresentative if the corpus is
    # ordered, e.g. by source paper. faiss also warns when it gets fewer than
    # ~39 training points per centroid. Confirm the sample is adequate.
    Xt = sentence_embeddings[:39000]

    # Product quantisation: split each vector into M sub-vectors and encode each
    # one with nbits bits (8 bits -> one byte per sub-vector). M must divide D.
    M = 16
    nbits = 8
    # Number of inverted lists (cells). 888 is roughly sqrt(789,188), the corpus size.
    nlist = 888
    # Neighbours per node in the HNSW graph used as the coarse quantizer.
    hnsw_m = 32

    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
    index.train(Xt)
    index.add(sentence_embeddings)
    # Runtime search parameter: number of cells visited per query.
    index.nprobe = 8
    return sentence_embeddings, index
@st.cache(allow_output_mutation=True)
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
    """Return the corpus sentences nearest to ``input_text``, nearest first.

    The query is encoded with ``model``. The final hidden state at the first
    token position is L2-normalised and searched in ``index``; for BERT-style
    tokenizers this is presumably the ``[CLS]`` token (confirm for this
    checkpoint).

    Args:
        index: FAISS index over the sentence embeddings. Its result ids are used
            as row labels of ``sentence_df`` via ``.loc``, so the frame's labels
            are assumed to match embedding row positions (TODO confirm).
        input_text: Query text.
        top_k: Maximum number of sentences to return.
        model: Hugging Face encoder whose output has ``last_hidden_state``.
        tokenizer: Tokenizer matching ``model``.
        sentence_df: DataFrame with a ``"sentence"`` column.

    Returns:
        pandas.DataFrame with a single ``"sentences"`` column. It can hold fewer
        than ``top_k`` rows: FAISS pads unfilled result slots with id ``-1``
        (e.g. when the visited cells hold fewer than ``top_k`` vectors), and
        those padding ids are dropped.

    NOTE(review): ``st.cache`` hashes every argument on each call, including
    the index and the model, which can be slow.
    """
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        # First-token final hidden state of the single query sequence.
        query_embedding = outputs.last_hidden_state[:, 0, :][0]
        query_embedding = query_embedding.detach().cpu().numpy()

    # Normalise the query to unit length before searching.
    query_embedding = query_embedding / np.linalg.norm(query_embedding, ord=2)
    _, ids = index.search(x=np.array([query_embedding]), k=top_k)

    # -1 marks an empty result slot, not a real row, and would raise KeyError in .loc.
    hit_ids = [int(i) for i in ids[0] if i >= 0]
    retrieved_sentences = sentence_df.loc[hit_ids, "sentence"].tolist()
    return pd.DataFrame({"sentences": retrieved_sentences})
if __name__ == "__main__":
    # The loaders are memoised with st.cache, so Streamlit reruns reuse them.
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    # Only the FAISS index is needed at query time; the raw embeddings are unused.
    _, index = load_sentence_embeddings_and_index()

    st.markdown("## AI-based Paraphrasing for Academic Writing")
    input_text = st.text_area(
        "text input",
        "Model have good results.",
        placeholder="Write something here...",
    )
    top_k = st.number_input("top_k", min_value=1, value=10, step=1)

    search_clicked = st.button("search")
    if search_clicked:
        results = get_retrieval_results(
            index, input_text, top_k, model, tokenizer, sentence_df
        )
        st.table(results)