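# NOTE: the six lines below are page-header residue from the web file viewer
# this source was copied from; they are not part of the program.
# kaisugi's picture
# update
# e8c441c
# raw
# history blame
# 3.11 kB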
import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer
import os
# Presumably works around the "duplicate OpenMP runtime" abort that can occur
# when several of the imported native libraries each bundle OpenMP.
# NOTE(review): confirm this is still needed and takes effect at this point.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    """Load the pretrained encoder and its matching tokenizer.

    Both objects are created with ``from_pretrained`` from the same
    checkpoint name, and the model is put into evaluation mode before it
    is handed back. The call is wrapped in ``st.cache``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "kaisugi/scitoricsbert"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)
    encoder.eval()  # switch to evaluation mode
    return encoder, tokenizer
@st.cache(allow_output_mutation=True)
def load_sentence_data():
    """Read the gzip-compressed sentence CSV and return it as a DataFrame.

    The file is looked up relative to the current working directory.
    """
    return pd.read_csv("sentence_data_789k.csv.gz")
@st.cache(allow_output_mutation=True)
def load_sentence_embeddings():
    """Load the sentence embedding array from the ``.npz`` archive.

    Returns:
        The array stored in the archive under the key ``"arr_0"``.
    """
    # ``np.load`` on an ``.npz`` file returns an archive object that keeps the
    # underlying file open. Reading inside ``with`` makes sure the handle is
    # closed once the array has been read out.
    with np.load("sentence_embeddings_789k.npz") as archive:
        return archive["arr_0"]
@st.cache(allow_output_mutation=True)
def build_faiss_index(sentence_emeddings):
    """Build, train and fill an IVF-PQ index over the given embeddings.

    An ``IndexHNSWFlat`` is used as the coarse quantizer of an
    ``IndexIVFPQ``. The index is trained on the first 39000 rows, then all
    rows are added and ``nprobe`` is set for search time.

    Args:
        sentence_emeddings: 2-D array with one embedding per row (the
            parameter name's spelling is kept so existing callers keep
            working). Its second dimension defines the index dimension.

    Returns:
        The trained and populated ``faiss.IndexIVFPQ``.
    """
    # Take the vector dimension from the data rather than hard-coding it, so
    # the index always matches the embeddings it is given.
    # NOTE(review): product quantization presumably needs this to be a
    # multiple of ``pq_m`` below — confirm against the faiss docs.
    dim = sentence_emeddings.shape[1]

    # Param of PQ
    pq_m = 16  # The number of sub-vectors. Typically this is 8, 16, 32, etc.
    nbits = 8  # Bits per sub-vector. Typically 8, so each sub-vec is encoded by 1 byte.
    # Param of IVF
    nlist = 1000  # The number of cells (space partition). Typical value is sqrt(number of vectors).
    # Param of HNSW
    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32.

    # Only a prefix of the data is used for training; everything is added.
    # NOTE(review): 39000 == 39 * nlist; presumably chosen on purpose — confirm.
    train_vectors = sentence_emeddings[:39000]

    # Setup
    quantizer = faiss.IndexHNSWFlat(dim, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, dim, nlist, pq_m, nbits)

    # Train, then add every vector.
    index.train(train_vectors)
    index.add(sentence_emeddings)

    # Runtime param: the number of cells that are visited for search.
    index.nprobe = 8
    return index
@st.cache
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
    """Embed ``input_text`` and search ``index`` for its nearest neighbours.

    The text is tokenized (truncated to 512 tokens), passed through
    ``model``, and the hidden state of the first token of the first
    sequence is used as the query vector. That vector is scaled to unit L2
    norm and searched in ``index``. The query batch, distances and ids are
    also printed to stdout for debugging.

    Args:
        index: Search index exposing ``search(x=..., k=...)``.
        input_text: Query text to embed.
        top_k: Number of neighbours to retrieve.
        model: Encoder whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer matching ``model``.
        sentence_df: Currently unused by this function.

    Returns:
        The ``(dists, ids)`` pair returned by ``index.search``.
    """
    # NOTE(review): ``st.cache`` hashes every argument, including ``index``,
    # ``model`` and ``tokenizer``; confirm Streamlit can hash these objects
    # (or supply ``hash_funcs``).
    with torch.no_grad():
        # Calling the tokenizer directly replaces the deprecated
        # ``encode_plus`` and takes the same arguments.
        inputs = tokenizer(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )
        outputs = model(**inputs)

    # First token's hidden state of the first (only) sequence in the batch.
    query_embedding = outputs.last_hidden_state[:, 0, :][0]
    query_embedding = query_embedding.detach().cpu().numpy()
    # Scale to unit L2 norm.
    query_embedding = query_embedding / np.linalg.norm(query_embedding, ord=2)

    # Build the single-row query batch once and reuse it.
    query_batch = np.array([query_embedding])
    print(query_batch)

    dists, ids = index.search(x=query_batch, k=top_k)
    print(dists)
    print(ids)

    # Hand the results back so callers can actually use them.
    return dists, ids
def main(model, tokenizer, sentence_df, sentence_embeddings, index):
    """Render the page widgets and run retrieval for the current inputs.

    Shows a heading, a text area pre-filled with an example sentence and a
    numeric ``top_k`` input, then passes the entered values to
    ``get_retrieval_results``.

    Args:
        model: Encoder forwarded to ``get_retrieval_results``.
        tokenizer: Tokenizer forwarded to ``get_retrieval_results``.
        sentence_df: Sentence table forwarded to ``get_retrieval_results``.
        sentence_embeddings: Not used inside this function.
        index: Search index forwarded to ``get_retrieval_results``.
    """
    st.markdown("## AI-based Paraphrasing for Academic Writing")

    query_text = st.text_area(
        "text input",
        "Model have good results.",
        placeholder="Write something here...",
    )
    neighbour_count = st.number_input('top_k', min_value=1, value=10, step=1)

    get_retrieval_results(
        index, query_text, neighbour_count, model, tokenizer, sentence_df
    )
if __name__ == "__main__":
    # Load the model, the sentence table and the embedding matrix.
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings = load_sentence_embeddings()

    # Normalise the embeddings (the call's return value is not used, the
    # array itself is passed on) before the index is built from them.
    faiss.normalize_L2(sentence_embeddings)
    index = build_faiss_index(sentence_embeddings)

    main(
        model=model,
        tokenizer=tokenizer,
        sentence_df=sentence_df,
        sentence_embeddings=sentence_embeddings,
        index=index,
    )