import os

# Must be set before faiss/torch pull in their own OpenMP runtimes, otherwise
# loading both can abort with "OMP: Error #15: Initializing libiomp5 ...".
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer


@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    """Load the ScitoricsBERT encoder and its tokenizer (cached across reruns).

    Returns:
        tuple: (model, tokenizer) with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
    model = AutoModel.from_pretrained("kaisugi/scitoricsbert")
    model.eval()
    return model, tokenizer


@st.cache(allow_output_mutation=True)
def load_sentence_data():
    """Load the corpus metadata (one row per indexed sentence) from the
    gzipped CSV shipped next to this script."""
    return pd.read_csv("sentence_data_789k.csv.gz")


@st.cache(allow_output_mutation=True)
def load_sentence_embeddings():
    """Load the precomputed sentence embeddings.

    Returns:
        np.ndarray: array stored under key "arr_0" of the .npz archive;
        row i corresponds to row i of the sentence data frame.
    """
    npz_comp = np.load("sentence_embeddings_789k.npz")
    return npz_comp["arr_0"]


@st.cache(allow_output_mutation=True)
def build_faiss_index(sentence_embeddings):
    """Build an IVF-PQ index with an HNSW coarse quantizer over the corpus.

    Args:
        sentence_embeddings: (N, 768) float32 array, assumed L2-normalized
            by the caller so inner-product/L2 ranking matches cosine ranking.

    Returns:
        faiss.IndexIVFPQ: trained index with all vectors added and
        ``nprobe`` already set.
    """
    D = 768  # embedding dimensionality (BERT hidden size)
    # Corpus size is ~789188; train on a 39k subsample, which gives
    # 39 training points per IVF cell (39 * nlist) — the faiss-recommended minimum.
    Xt = sentence_embeddings[:39000]
    X = sentence_embeddings

    # Param of PQ
    M = 16  # The number of sub-vectors. Typically 8, 16, 32, etc.
    nbits = 8  # bits per sub-vector; 8 means each sub-vector is encoded by 1 byte.

    # Param of IVF
    nlist = 1000  # The number of cells (space partitions). Typical value is sqrt(N).

    # Param of HNSW
    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32.

    # Setup: HNSW graph is used only as the coarse quantizer for IVF-PQ.
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)

    index.train(Xt)
    index.add(X)

    # Runtime param: the number of IVF cells visited per query.
    index.nprobe = 8

    return index


# NOTE: deliberately NOT wrapped in @st.cache — legacy st.cache hashes every
# argument, and the faiss index / torch model / tokenizer are unhashable,
# which raises UnhashableTypeError at runtime.
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
    """Embed ``input_text`` and retrieve its ``top_k`` nearest corpus sentences.

    Args:
        index: trained faiss index over the corpus embeddings.
        input_text: query sentence typed by the user.
        top_k: number of neighbors to retrieve.
        model / tokenizer: ScitoricsBERT encoder pair.
        sentence_df: corpus metadata; row i corresponds to index id i.

    Returns:
        tuple: (retrieved_rows, distances) where ``retrieved_rows`` is the
        slice of ``sentence_df`` for the returned ids (in rank order) and
        ``distances`` is the matching 1-D array of faiss distances.
    """
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        outputs = model(**inputs)
        # [CLS] token embedding of the single input sentence.
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]

    query_embeddings = query_embeddings.detach().cpu().numpy()
    # L2-normalize so the query lives on the same unit sphere as the corpus
    # (the corpus is normalized via faiss.normalize_L2 in __main__).
    query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)

    # faiss requires a contiguous float32 (n, d) batch.
    query_batch = np.ascontiguousarray([query_embeddings], dtype=np.float32)
    dists, ids = index.search(query_batch, top_k)

    # faiss pads with -1 when fewer than top_k results exist; drop those.
    valid_ids = [i for i in ids[0] if i >= 0]
    return sentence_df.iloc[valid_ids], dists[0][: len(valid_ids)]


def main(model, tokenizer, sentence_df, sentence_embeddings, index):
    """Render the Streamlit page and display retrieval results for the query."""
    st.markdown("## AI-based Paraphrasing for Academic Writing")

    input_text = st.text_area(
        "text input", "Model have good results.", placeholder="Write something here..."
    )
    top_k = st.number_input('top_k', min_value=1, value=10, step=1)

    results, dists = get_retrieval_results(
        index, input_text, int(top_k), model, tokenizer, sentence_df
    )
    # Show the retrieved sentences in the UI instead of printing server-side.
    st.table(results)


if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings = load_sentence_embeddings()

    # Normalize in place so index distances correspond to cosine similarity.
    faiss.normalize_L2(sentence_embeddings)
    index = build_faiss_index(sentence_embeddings)

    main(model, tokenizer, sentence_df, sentence_embeddings, index)