Spaces:
Runtime error
Runtime error
| from transformers import AutoModel, AutoTokenizer | |
| import faiss | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import torch | |
| import math | |
| import os | |
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort that can occur when several native libraries each bundle their own
# OpenMP copy -- confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
def load_model_and_tokenizer():
    """Load the pretrained encoder and its matching tokenizer.

    Both are fetched from the ``kaisugi/scitoricsbert`` checkpoint. The model
    is put into evaluation mode before being handed back.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "kaisugi/scitoricsbert"
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)
    encoder.eval()
    return encoder, text_tokenizer
def load_sentence_data():
    """Read the sentence table from ``sentence_data_858k.csv.gz``.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_csv`` for that file.
    """
    return pd.read_csv("sentence_data_858k.csv.gz")
def load_sentence_embeddings_and_index():
    """Load the sentence embeddings and build an IVF-PQ index over them.

    The embeddings are read from ``sentence_embeddings_858k.npz`` (array
    ``arr_0``), L2-normalized, and added to a ``faiss.IndexIVFPQ`` whose
    coarse quantizer is a ``faiss.IndexHNSWFlat``.

    Returns:
        A ``(sentence_embeddings, index)`` tuple: the normalized embedding
        matrix and the trained, populated index with ``nprobe`` already set.
    """
    # Close the archive once the array has been read instead of leaving the
    # file handle open for the rest of the process.
    with np.load("sentence_embeddings_858k.npz") as npz_comp:
        sentence_embeddings = npz_comp["arr_0"]
    faiss.normalize_L2(sentence_embeddings)

    # Take the sizes from the data itself so the index always matches the
    # file that was actually loaded (previously hard-coded 857610 x 768).
    num_vectors, dim = sentence_embeddings.shape

    # Only a prefix of the data is used for training; everything is added.
    train_size = 100000
    train_vectors = sentence_embeddings[:train_size]

    # Param of PQ
    num_subvectors = 16  # The number of sub-vectors. Typically 8, 16, 32, etc.
    nbits = 8  # Bits per sub-vector; 8 means each sub-vector is one byte.

    # Param of IVF
    # The number of cells (space partition). Typical value is sqrt(N);
    # never let it drop to zero for very small inputs.
    nlist = max(1, int(math.sqrt(num_vectors)))

    # Param of HNSW
    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32.

    # Setup
    quantizer = faiss.IndexHNSWFlat(dim, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, dim, nlist, num_subvectors, nbits)

    # Train, then add every vector.
    index.train(train_vectors)
    index.add(sentence_embeddings)

    # Runtime param: the number of cells that are visited for each search.
    index.nprobe = 8
    return sentence_embeddings, index
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list):
    """Retrieve the sentences nearest to ``input_text`` and format them.

    The query is encoded with ``model``/``tokenizer``, the first-token hidden
    state is L2-normalized, and ``index.search`` is asked for ``top_k``
    neighbours. Hits are looked up in ``sentence_df`` by label.

    Args:
        index: Object exposing ``search(x, k)`` that returns
            ``(distances, ids)``.
        input_text: Query text.
        top_k: Upper bound on the number of rows returned.
        model: Encoder called as ``model(**inputs)``; its output must expose
            ``last_hidden_state``.
        tokenizer: Tokenizer exposing ``encode_plus``.
        sentence_df: DataFrame with ``sentence`` and ``file_id`` columns,
            indexed by the ids the index returns.
        exclude_word_list: Literal substrings; any retrieved sentence that
            contains one of them (case-sensitive) is dropped.

    Returns:
        A DataFrame with columns ``sentence`` and ``source link``.
    """
    import re  # local import: only needed to escape the exclude words

    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()
    query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)

    _, ids = index.search(x=np.array([query_embeddings]), k=int(top_k))

    # FAISS pads the id list with -1 when fewer than k neighbours are found;
    # those placeholders are not valid row labels and must be dropped.
    hit_ids = np.asarray(ids)[0]
    hit_ids = hit_ids[hit_ids >= 0]

    hits = sentence_df.loc[hit_ids, ["sentence", "file_id"]]
    all_df = pd.DataFrame({
        "sentence": hits["sentence"].tolist(),
        "source link": [
            f"https://aclanthology.org/{file_id}" for file_id in hits["file_id"]
        ],
    })

    if not exclude_word_list or all_df.empty:
        return all_df

    # Escape each word so user input such as "c++" or "(" is matched
    # literally instead of being interpreted (or rejected) as a regex.
    exclude_word_list_regex = '|'.join(re.escape(word) for word in exclude_word_list)
    # na=False keeps the mask boolean even if a sentence cell is missing.
    excluded = all_df["sentence"].str.contains(exclude_word_list_regex, na=False)
    return all_df[~excluded]
if __name__ == "__main__":
    # Load the encoder, the sentence table and the search index up front.
    # NOTE(review): Streamlit re-executes the script on widget interaction,
    # so these loaders presumably run again on every rerun -- consider
    # caching them if the installed Streamlit version supports it.
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()

    # Page header and input widgets.
    st.markdown("## AI-based Paraphrasing for Academic Writing")
    input_text = st.text_area(
        "text input",
        "We saw difference in the results between A and B.",
        placeholder="Write something here...",
    )
    top_k = st.number_input('top_k (upperbound)', min_value=1, value=200, step=1)
    input_words = st.text_input("exclude words (comma separated)", "see, saw")

    # Run the retrieval only when the user presses the button.
    if st.button('search'):
        # Split on commas, trim whitespace, and drop empty entries.
        exclude_word_list = [
            word
            for word in (part.strip() for part in input_words.split(","))
            if word
        ]
        df = get_retrieval_results(
            index,
            input_text,
            top_k,
            model,
            tokenizer,
            sentence_df,
            exclude_word_list,
        )
        st.table(df)