from transformers import AutoModel, AutoTokenizer
import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
import os
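# faiss and PyTorch can each load their own copy of the OpenMP runtime;
# permitting duplicate libomp copies (below) avoids a hard crash at import time.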
os.environ['KMP_DUPLICATE_LIB_OK']='True'
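
# st.cache(allow_output_mutation=True) keeps these heavyweight objects alive
# across Streamlit reruns; newer Streamlit versions provide st.cache_resource
# for the same purpose.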
@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
    model = AutoModel.from_pretrained("kaisugi/scitoricsbert")
    model.eval()
    return model, tokenizer
@st.cache(allow_output_mutation=True)
def load_sentence_data():
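    # The CSV is expected to provide at least a "sentence" column and a
    # "file_id" column (an ACL Anthology ID); both are read at retrieval time.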
    sentence_df = pd.read_csv("sentence_data_789k.csv.gz")
    return sentence_df
@st.cache(allow_output_mutation=True)
def load_sentence_embeddings_and_index():
    npz_comp = np.load("sentence_embeddings_789k.npz")
    sentence_embeddings = npz_comp["arr_0"]
    faiss.normalize_L2(sentence_embeddings)
    D = 768     # embedding dimension (BERT hidden size)
    N = 789188  # number of sentence vectors
    Xt = sentence_embeddings[:39000]  # training subset for the quantizers
    X = sentence_embeddings
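    # Sanity checks on the setup below: 39,000 training points is about 44 per
    # IVF cell, above the ~39 per centroid that faiss's k-means warns about.
    # The PQ codes cost N * M = 789,188 * 16 bytes, roughly 12.6 MB, versus
    # about 2.4 GB for the raw 768-dim float32 embeddings.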
    # PQ parameters
    M = 16       # number of sub-vectors; typically 8, 16, 32, etc.
    nbits = 8    # bits per sub-vector; typically 8, so each sub-vector is encoded in 1 byte
    # IVF parameter
    nlist = 888  # number of cells (space partitions); a typical value is sqrt(N)
    # HNSW parameter
    hnsw_m = 32  # number of neighbors per node in the HNSW graph; typically 32
    # Setup: IVFPQ index with an HNSW-based coarse quantizer
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
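    # The HNSW graph over the nlist centroids makes the coarse assignment
    # (choosing which cells to probe) approximate, but it scales much better
    # than a flat scan of the centroids as nlist grows.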
    # Train the coarse quantizer and PQ codebooks, then add all vectors
    index.train(Xt)
    index.add(X)
    # Runtime parameter: the number of cells visited per search
    index.nprobe = 8
    return sentence_embeddings, index
@st.cache(allow_output_mutation=True)
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list):
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        # Use the final hidden state of the [CLS] token as the sentence embedding
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)
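        # Both the corpus and the query are L2-normalized, so the index's
        # default L2 distances rank results identically to cosine similarity.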
        _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
        retrieved_sentences = []
        retrieved_paper_ids = []
        for idx in ids[0]:
            if idx == -1:  # faiss pads with -1 when fewer than k results are found
                continue
            retrieved_sentences.append(sentence_df.loc[idx, "sentence"])
            retrieved_paper_ids.append(f"https://aclanthology.org/{sentence_df.loc[idx, 'file_id']}")
        all_df = pd.DataFrame({"sentence": retrieved_sentences, "source link": retrieved_paper_ids})
        if len(exclude_word_list) == 0:
            return all_df
        else:
            exclude_word_list_regex = '|'.join(exclude_word_list)
            return all_df[~all_df["sentence"].str.contains(exclude_word_list_regex)]
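
# Note: str.contains treats the joined terms as a plain substring regex, so
# excluding "see" also drops sentences containing "seen" or "oversee", and
# regex metacharacters in user input are interpreted literally as a pattern.
# A whole-word variant (a sketch, not part of the original app; needs
# "import re") could escape and anchor each term:
#   pattern = r"\b(?:" + "|".join(map(re.escape, exclude_word_list)) + r")\b"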
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()

    st.markdown("## AI-based Paraphrasing for Academic Writing")
    input_text = st.text_area("text input", "We saw difference in the results between A and B.", placeholder="Write something here...")
    top_k = st.number_input('top_k (upper bound)', min_value=1, value=30, step=1)
    input_words = st.text_input("exclude words (comma separated)", "see, saw")

    if st.button('search'):
        exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
        df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list)
        st.table(df)
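
# To run locally (assuming sentence_data_789k.csv.gz and
# sentence_embeddings_789k.npz sit next to this script):
#   streamlit run app.py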