import math
import os
import re

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

# Allow duplicate OpenMP runtimes; faiss and PyTorch each bundle their own
# copy of libomp, which otherwise aborts the process on some platforms.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    # output_attentions=True is required by formulaic_phrase_extraction below.
    tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
    model = AutoModel.from_pretrained("kaisugi/scitoricsbert", output_attentions=True)
    model.eval()
    return model, tokenizer

@st.cache(allow_output_mutation=True)
def load_sentence_data():
    sentence_df = pd.read_csv("sentence_data_858k.csv.gz")
    return sentence_df
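
# NOTE (assumption, inferred from the lookups in get_retrieval_results below):
# sentence_data_858k.csv.gz holds one row per indexed sentence, with at least
# a "sentence" column (the text) and a "file_id" column (the ACL Anthology id
# used to build source links).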

@st.cache(allow_output_mutation=True)
def load_sentence_embeddings_and_index():
    npz_comp = np.load("sentence_embeddings_858k.npz")
    sentence_embeddings = npz_comp["arr_0"]
    faiss.normalize_L2(sentence_embeddings)
    D = 768     # embedding dimension
    N = 857610  # number of indexed sentences
    Xt = sentence_embeddings[:100000]  # training subset for the quantizers
    X = sentence_embeddings
    # Params of PQ
    M = 16     # number of sub-vectors; typically 8, 16, 32, etc.
    nbits = 8  # bits per sub-vector code; 8 means each sub-vector is encoded by 1 byte
    # Param of IVF
    nlist = int(math.sqrt(N))  # number of cells (space partitions); sqrt(N) is a typical value
    # Param of HNSW
    hnsw_m = 32  # number of neighbors in the HNSW graph; 32 is typical
    # Setup: IVF-PQ index with an HNSW coarse quantizer
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
    # Train on the subset, then add all vectors
    index.train(Xt)
    index.add(X)
    # Runtime param: the number of cells visited per search
    index.nprobe = 8
    return sentence_embeddings, index
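
# Back-of-the-envelope check of the PQ compression: each vector is stored as
# M codes of nbits bits, i.e. 16 * 8 bits = 16 bytes per sentence, versus
# 768 * 4 = 3,072 bytes for the raw float32 embedding -- roughly a 192x
# reduction before counting IVF/HNSW overhead. The helper below is a local
# debugging sketch, not part of the original app: it queries the index with a
# random unit vector and returns the neighbor ids.
def _sanity_check_index(index, d=768, k=5):
    rng = np.random.default_rng(0)
    q = rng.standard_normal((1, d)).astype("float32")
    faiss.normalize_L2(q)
    _, ids = index.search(q, k)
    return ids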

@st.cache(allow_output_mutation=True)
def formulaic_phrase_extraction(sentences, model, tokenizer):
    THRESHOLD = 0.01
    LAYER = 10
    output_sentences = []
    with torch.no_grad():
        inputs = tokenizer.batch_encode_plus(
            sentences,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        attention = outputs[-1]
        # Head-averaged attention from the [CLS] token, one row per sentence;
        # attention[LAYER] has shape (batch, heads, seq, seq).
        cls_attentions = torch.mean(attention[LAYER], dim=1)[:, 0, :]
        for sentence, cls_attention in zip(sentences, cls_attentions):
            tokens = tokenizer.tokenize(sentence)
            # Keep flags for the sentence's own tokens, dropping [CLS] and
            # everything after them ([SEP] and padding).
            check_bool_arr = list((cls_attention > THRESHOLD).numpy())[1:1 + len(tokens)]
            cur_tokens = tokens.copy()
            # Merge WordPiece continuations ("##...") back into whole words,
            # AND-ing the per-piece flags so a word counts as formulaic only
            # if all of its pieces exceed the threshold.
            while True:
                flg = False
                for idx, token in enumerate(cur_tokens):
                    if token.startswith("##"):
                        flg = True
                        back_token = token.replace("##", "")
                        front_token = cur_tokens.pop(idx - 1)
                        cur_tokens[idx - 1] = front_token + back_token
                        back_bool_val = check_bool_arr[idx]
                        front_bool_val = check_bool_arr.pop(idx - 1)
                        check_bool_arr[idx - 1] = front_bool_val and back_bool_val
                        break  # restart the scan; the list was mutated
                if not flg:
                    break
            result = " ".join([
                f'<font color="coral">{original_word}</font>' if b else original_word
                for (b, original_word) in zip(check_bool_arr, sentence.split())
            ])
            output_sentences.append(result)
    return output_sentences
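
# Illustrative (hypothetical) example -- the actual highlighted span depends
# on the learned attention weights:
#   formulaic_phrase_extraction(["Our model shows good results."], model, tokenizer)
# might return something like
#   ['Our model <font color="coral">shows good</font> results.']
# with the formulaic phrase wrapped in a coral <font> tag for st.markdown.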

@st.cache(allow_output_mutation=True)
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=True):
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        # [CLS] embedding of the query, L2-normalized to match the index
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)
        _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
    # Compile the exclude pattern once, escaping user-supplied words so regex
    # metacharacters are treated literally.
    exclude_pat = None
    if len(exclude_word_list) > 0:
        exclude_pat = re.compile('|'.join(re.escape(w) for w in exclude_word_list))
    retrieved_sentences = []
    retrieved_paper_ids = []
    for cur_id in ids[0]:
        cur_sentence = sentence_df.loc[cur_id, "sentence"]
        cur_link = f"https://aclanthology.org/{sentence_df.loc[cur_id, 'file_id']}"
        if exclude_pat is None or not exclude_pat.search(cur_sentence):
            retrieved_sentences.append(cur_sentence)
            retrieved_paper_ids.append(cur_link)
    if phrase_annotated:
        retrieved_sentences = formulaic_phrase_extraction(retrieved_sentences, model, tokenizer)
    return retrieved_sentences, retrieved_paper_ids
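
# Example call (a sketch; the values mirror the UI defaults below):
#   sentences, links = get_retrieval_results(
#       index, "Our model shows good results.", 30,
#       model, tokenizer, sentence_df, ["good", "result"],
#   )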

if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()
    st.markdown("## AI-based Paraphrasing for Academic Writing")
    input_text = st.text_area("text input", "Our model shows good results.", placeholder="Write something here...")
    top_k = st.number_input('top_k (upper bound)', min_value=1, value=30, step=1)
    input_words = st.text_input("exclude words (comma-separated)", "good, result")
    agree = st.checkbox('Include phrase annotation')
    if st.button('search'):
        exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
        retrieved_sentences, retrieved_paper_ids = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=agree)
        result_table_markdown = "| sentence | source link |\n|:---|:---|\n"
        for (retrieved_sentence, retrieved_paper_id) in zip(retrieved_sentences, retrieved_paper_ids):
            result_table_markdown += f"| {retrieved_sentence} | {retrieved_paper_id} |\n"
        st.markdown(result_table_markdown, unsafe_allow_html=True)
    st.markdown("---\n#### How this works")
    st.markdown("This app uses ScitoricsBERT [(Sugimoto and Aizawa, 2022)](https://aclanthology.org/2022.sdp-1.7/), a functional sentence representation model, to retrieve sentences that are functionally similar to the input. It also extracts phrasal patterns that accord with the function, by leveraging self-attention patterns within ScitoricsBERT.")