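"""Streamlit demo: AI-based paraphrasing support for academic writing.

Encodes the input sentence with ScitoricsBERT and retrieves the most similar
sentences from a ~789k-sentence corpus via an approximate nearest-neighbor
FAISS index (IVF + PQ, with an HNSW coarse quantizer).
"""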
import os

# Allow duplicate OpenMP runtimes (faiss and torch each ship their own libomp);
# this must be set before faiss/torch are imported to take effect.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer
@st.cache_resource
def load_model_and_tokenizer():
    # Load the ScitoricsBERT encoder once and cache it across Streamlit reruns.
    tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
    model = AutoModel.from_pretrained("kaisugi/scitoricsbert")
    model.eval()  # inference only; disable dropout
    return model, tokenizer
@st.cache_data
def load_sentence_data():
    # ~789k candidate sentences, one per row, in a "sentence" column.
    sentence_df = pd.read_csv("sentence_data_789k.csv.gz")
    return sentence_df
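# NOTE: get_retrieval_results() looks rows up by FAISS id via sentence_df.loc[id, ...],
# so the CSV's default RangeIndex must line up row-for-row with the order of
# vectors in sentence_embeddings_789k.npz.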
@st.cache_resource
def load_sentence_embeddings_and_index():
    # Pre-computed CLS embeddings for all candidate sentences.
    npz_comp = np.load("sentence_embeddings_789k.npz")
    sentence_embeddings = npz_comp["arr_0"]

    # L2-normalize in place so that L2 distance ranks results the same way
    # as cosine similarity.
    faiss.normalize_L2(sentence_embeddings)

    D = 768      # embedding dimension (BERT hidden size)
    N = 789188   # number of sentences in the corpus
    Xt = sentence_embeddings[:39000]  # training subset for IVF/PQ
    X = sentence_embeddings

    # Params of PQ
    M = 16       # number of sub-vectors; typically 8, 16, 32, etc.
    nbits = 8    # bits per sub-vector; 8 means each sub-vector is encoded by 1 byte

    # Param of IVF
    nlist = 888  # number of cells (space partitions); typical value is sqrt(N)

    # Param of HNSW
    hnsw_m = 32  # number of neighbors for HNSW; typically 32

    # Setup: HNSW coarse quantizer over the IVF cells, PQ-compressed residuals.
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)

    # Train on the subset, then add the full set.
    index.train(Xt)
    index.add(X)

    # Runtime param: the number of cells visited per search.
    index.nprobe = 8

    return sentence_embeddings, index
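# Why IVFPQ here: PQ stores each vector as M * nbits / 8 = 16 bytes of codes,
# so the 789,188 codes take ~12.6 MB, versus ~2.4 GB for the raw float32
# 768-d embeddings (789188 * 768 * 4 bytes). nlist = 888 follows the
# sqrt(N) rule of thumb (sqrt(789188) ~= 888.4).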
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
    with torch.no_grad():
        inputs = tokenizer(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        outputs = model(**inputs)
        # Use the [CLS] token embedding as the sentence representation.
        query_embedding = outputs.last_hidden_state[:, 0, :][0]
        query_embedding = query_embedding.detach().cpu().numpy()
        # L2-normalize the query to match the normalized index vectors.
        query_embedding = query_embedding / np.linalg.norm(query_embedding, ord=2)

    _, ids = index.search(np.array([query_embedding], dtype=np.float32), top_k)
    retrieved_sentences = [sentence_df.loc[idx, "sentence"] for idx in ids[0]]
    return pd.DataFrame({"sentences": retrieved_sentences})
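# Example: with the default inputs below, get_retrieval_results(index,
# "Model have good results.", 10, model, tokenizer, sentence_df) returns a
# one-column DataFrame ("sentences") holding the 10 corpus sentences whose
# normalized CLS embeddings are nearest to the query's (which, on unit
# vectors, is the same ranking as cosine similarity).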
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()

    st.markdown("## AI-based Paraphrasing for Academic Writing")

    input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
    top_k = st.number_input("top_k", min_value=1, value=10, step=1)

    if st.button("search"):
        df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
        st.table(df)
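# To try the app locally (assuming this file is saved as app.py and the data
# files sit alongside it):
#   streamlit run app.py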