Spaces:
Runtime error
Runtime error
File size: 3,693 Bytes
b6363d9 e8c441c 60689d7 e8c441c 60689d7 e8c441c e1a3f25 60689d7 e8c441c e1a3f25 60689d7 e1a3f25 60689d7 e1a3f25 e8c441c e1a3f25 e8c441c e1a3f25 7db6000 e8c441c 05f1914 7db6000 05f1914 7db6000 e8c441c 7db6000 e8c441c e1a3f25 b6363d9 e1a3f25 b6363d9 7db6000 60689d7 7db6000 e8c441c e1a3f25 7db6000 e1a3f25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from transformers import AutoModel, AutoTokenizer
import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
import math
import os
# NOTE(review): KMP_DUPLICATE_LIB_OK is an Intel OpenMP runtime switch; presumably
# it is set so the process keeps running when more than one OpenMP runtime gets
# loaded (e.g. via torch and faiss) — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    """Load the sentence encoder and its tokenizer.

    Both are fetched from the "kaisugi/scitoricsbert" checkpoint and the
    model is switched to evaluation mode before being handed back. The
    result is cached by Streamlit so the download/initialisation is not
    repeated on every script rerun.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "kaisugi/scitoricsbert"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint)
    # Inference only: put the module in eval mode.
    model.eval()
    return model, tokenizer
@st.cache(allow_output_mutation=True)
def load_sentence_data():
    """Read the sentence table from ``sentence_data_858k.csv.gz``.

    The result is cached by Streamlit across script reruns.

    Returns:
        The ``pandas.DataFrame`` parsed from the gzip-compressed CSV.
        Other code in this file looks up its ``"sentence"`` and
        ``"file_id"`` columns by row label.
    """
    return pd.read_csv("sentence_data_858k.csv.gz")
@st.cache(allow_output_mutation=True)
def load_sentence_embeddings_and_index():
    """Load the sentence embeddings and build the FAISS search index.

    The embeddings are read from ``sentence_embeddings_858k.npz`` (array
    ``arr_0``), L2-normalised in place, and indexed with an IVF-PQ index
    whose coarse quantizer is an HNSW graph. The index is trained on the
    first 100000 vectors and then populated with all of them.

    The vector count and dimensionality are taken from the loaded array
    rather than hard-coded, so a regenerated embeddings file of a different
    size still produces a consistently configured index.

    Returns:
        A ``(sentence_embeddings, index)`` tuple: the normalised float32
        embedding matrix and the trained, populated ``faiss.IndexIVFPQ``
        with ``nprobe`` set to 8.
    """
    # NpzFile keeps the archive open until closed; release it once the
    # array has been materialised.
    with np.load("sentence_embeddings_858k.npz") as npz_comp:
        # FAISS operates on C-contiguous float32 matrices; this is a no-op
        # when the stored array already has that layout.
        sentence_embeddings = np.ascontiguousarray(npz_comp["arr_0"], dtype=np.float32)
    faiss.normalize_L2(sentence_embeddings)

    num_vectors, dim = sentence_embeddings.shape
    train_vectors = sentence_embeddings[:100000]

    # Param of PQ
    M = 16  # The number of sub-vector. Typically this is 8, 16, 32, etc.
    nbits = 8  # bits per sub-vector. This is typically 8, so that each sub-vec is encoded by 1 byte
    # Param of IVF
    # The number of cells (space partition). Typical value is sqrt(N);
    # clamp to at least one cell so a tiny corpus still yields a valid index.
    nlist = max(1, int(math.sqrt(num_vectors)))
    # Param of HNSW
    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32

    # Setup
    quantizer = faiss.IndexHNSWFlat(dim, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, dim, nlist, M, nbits)
    # Train
    index.train(train_vectors)
    # Add
    index.add(sentence_embeddings)
    # Search
    index.nprobe = 8  # Runtime param. The number of cells that are visited for search.
    return sentence_embeddings, index
@st.cache(allow_output_mutation=True)
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list):
    """Retrieve the sentences most similar to ``input_text``.

    The query is encoded with ``model``; the hidden state at the first token
    position is L2-normalised and used to search ``index`` for up to
    ``top_k`` neighbours. Hits are mapped back to rows of ``sentence_df``
    and sentences containing any of the excluded words are dropped.

    Args:
        index: Search index exposing ``search(x=..., k=...)``.
        input_text: Query text to paraphrase.
        top_k: Upper bound on the number of neighbours requested.
        model: Encoder called as ``model(**inputs)``; must expose
            ``last_hidden_state`` on its output.
        tokenizer: Tokenizer providing ``encode_plus``.
        sentence_df: DataFrame with ``"sentence"`` and ``"file_id"`` columns,
            addressed by the ids returned from ``index``.
        exclude_word_list: Words to filter out; matched literally as
            substrings of the sentence. An empty list disables filtering.

    Returns:
        A DataFrame with columns ``"sentence"`` and ``"source link"``,
        possibly with fewer than ``top_k`` rows.
    """
    # Local import: `re` is not among this file's top-level imports.
    import re

    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)

    _, ids = index.search(x=np.array([query_embeddings]), k=top_k)

    # FAISS pads the id array with -1 when the probed cells contain fewer
    # than k vectors; those placeholders are not valid row labels.
    hit_ids = [int(hit) for hit in ids[0] if hit >= 0]

    retrieved_sentences = [sentence_df.loc[hit, "sentence"] for hit in hit_ids]
    retrieved_paper_id = [
        f"https://aclanthology.org/{sentence_df.loc[hit, 'file_id']}" for hit in hit_ids
    ]
    all_df = pd.DataFrame({"sentence": retrieved_sentences, "source link": retrieved_paper_id})

    if not exclude_word_list:
        return all_df

    # Escape each word so user input such as "c++" or "(" is matched
    # literally instead of being interpreted as a regular expression.
    exclude_word_list_regex = '|'.join(re.escape(word) for word in exclude_word_list)
    # na=False keeps the mask boolean even if a sentence cell is missing.
    excluded = all_df["sentence"].str.contains(exclude_word_list_regex, na=False)
    return all_df[~excluded]
if __name__ == "__main__":
    # Cached loaders: the model, sentence table and index are built once
    # and reused across Streamlit reruns.
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()

    # Page layout: title, query box, result-count bound and exclusion list.
    st.markdown("## AI-based Paraphrasing for Academic Writing")
    input_text = st.text_area(
        "text input",
        "We saw difference in the results between A and B.",
        placeholder="Write something here...",
    )
    top_k = st.number_input('top_k (upperbound)', min_value=1, value=200, step=1)
    input_words = st.text_input("exclude words (comma separated)", "see, saw")

    if st.button('search'):
        # Split the comma-separated field, trim whitespace, drop empty entries.
        trimmed_words = (piece.strip() for piece in input_words.split(","))
        exclude_word_list = [word for word in trimmed_words if word]
        df = get_retrieval_results(
            index,
            input_text,
            top_k,
            model,
            tokenizer,
            sentence_df,
            exclude_word_list,
        )
        st.table(df)