import os
import re

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

# faiss and torch each bundle their own OpenMP runtime; allowing the duplicate
# avoids a hard crash ("OMP: Error #15") on some platforms.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


@st.cache_resource
def load_model_and_tokenizer():
    """Load the ScitoricsBERT encoder and tokenizer once and reuse them across reruns."""
    tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
    model = AutoModel.from_pretrained("kaisugi/scitoricsbert")
    model.eval()

    return model, tokenizer
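
# Example usage (sketch, not executed by the app): embedding one sentence the
# same way get_retrieval_results() does below, via the [CLS] vector of the
# 768-dimensional encoder.
#   model, tokenizer = load_model_and_tokenizer()
#   enc = tokenizer("An example sentence.", return_tensors="pt")
#   with torch.no_grad():
#       cls_vec = model(**enc).last_hidden_state[:, 0, :]  # shape: (1, 768)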


@st.cache_resource
def load_sentence_data():
    """Load the corpus of candidate sentences (one row per sentence)."""
    sentence_df = pd.read_csv("sentence_data_789k.csv.gz")

    return sentence_df
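
# The CSV is expected to provide at least a "sentence" column and a "file_id"
# column (an ACL Anthology ID), which get_retrieval_results() below reads.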


@st.cache_resource
def load_sentence_embeddings_and_index():
    """Load precomputed sentence embeddings and build an approximate-NN index."""
    npz_comp = np.load("sentence_embeddings_789k.npz")
    sentence_embeddings = npz_comp["arr_0"]

    # L2-normalize so that L2 distance in the index ranks neighbors by cosine similarity.
    faiss.normalize_L2(sentence_embeddings)
    D = 768     # embedding dimension (BERT-base hidden size)
    N = 789188  # total number of vectors; kept for reference (nlist ~= sqrt(N))
    Xt = sentence_embeddings[:39000]  # training subset for the quantizers
    X = sentence_embeddings

    # Product quantization (PQ) parameters
    M = 16     # number of sub-vectors; typically 8, 16, 32, ...
    nbits = 8  # bits per sub-vector; 8 means each sub-vector is encoded in one byte
    # Inverted file (IVF) parameter
    nlist = 888  # number of cells (space partitions); a typical value is sqrt(N)
    # HNSW parameter
    hnsw_m = 32  # number of neighbors per node; 32 is a common default
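    # Back-of-envelope: a PQ code is M * nbits / 8 = 16 bytes per vector, so
    # 789,188 codes need about 12.6 MB, versus about 2.4 GB for the raw float32
    # vectors (789,188 * 768 * 4 bytes); that compression is the point of IVFPQ.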

    # Coarse quantizer: an HNSW graph over cell centroids speeds up cell assignment.
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)

    # Train the coarse and product quantizers on the subset, then add all vectors.
    index.train(Xt)
    index.add(X)

    # Runtime parameter: number of cells visited per query (recall vs. speed trade-off).
    index.nprobe = 8

    return sentence_embeddings, index
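
# Example usage (sketch, not executed by the app): querying the returned index
# directly; ids may contain -1 when fewer than k vectors are found in the
# visited cells.
#   _, index = load_sentence_embeddings_and_index()
#   q = np.random.rand(1, 768).astype("float32")  # stand-in query vector
#   faiss.normalize_L2(q)
#   distances, ids = index.search(q, 5)  # both arrays have shape (1, 5)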


@st.cache_data
def get_retrieval_results(_index, input_text, top_k, _model, _tokenizer, _sentence_df, exclude_word_list):
    """Embed the query, search the index, and return a table of top_k similar sentences."""
    # Leading underscores tell Streamlit not to hash these arguments (faiss index,
    # model, tokenizer, DataFrame) when building the cache key.
    with torch.no_grad():
        inputs = _tokenizer(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        outputs = _model(**inputs)
        # Use the [CLS] token embedding as the sentence vector, L2-normalized
        # to match the normalization applied to the index.
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)

    _, ids = _index.search(x=np.array([query_embeddings]), k=top_k)
    retrieved_sentences = []
    retrieved_paper_id = []

    for idx in ids[0]:
        if idx == -1:  # faiss pads with -1 when fewer than top_k hits are found
            continue
        retrieved_sentences.append(_sentence_df.loc[idx, "sentence"])
        retrieved_paper_id.append(f"https://aclanthology.org/{_sentence_df.loc[idx, 'file_id']}")

    all_df = pd.DataFrame({"sentence": retrieved_sentences, "source link": retrieved_paper_id})

    if len(exclude_word_list) == 0:
        return all_df
    else:
        # Escape each word so regex metacharacters in user input are matched literally.
        exclude_word_list_regex = "|".join(re.escape(word) for word in exclude_word_list)
        return all_df[~all_df["sentence"].str.contains(exclude_word_list_regex)]
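
# Note: str.contains() matches substrings, so excluding "see" also drops
# sentences containing "seen" or "foresee"; wrap each escaped word in r"\b"
# boundaries if whole-word exclusion is wanted.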


if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()

    st.markdown("## AI-based Paraphrasing for Academic Writing")

    input_text = st.text_area("text input", "We saw difference in the results between A and B.", placeholder="Write something here...")
    top_k = st.number_input("top_k (upper bound)", min_value=1, value=30, step=1)
    input_words = st.text_input("exclude words (comma separated)", "see, saw")

    if st.button("search"):
        exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
        df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list)
        st.table(df)
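
# To run locally (assuming this file is saved as app.py and the two data files
# above sit in the working directory):
#   streamlit run app.py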