import math
import os
import re

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # allow duplicate OpenMP runtimes (faiss and torch each bundle one) instead of aborting


@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
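    """Load ScitoricsBERT and its tokenizer.

    `output_attentions=True` exposes the per-layer attention weights, which
    `formulaic_phrase_extraction` uses to highlight formulaic phrases.
    """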
    tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
    model = AutoModel.from_pretrained("kaisugi/scitoricsbert", output_attentions=True)
    model.eval()
    
    return model, tokenizer


@st.cache(allow_output_mutation=True)
def load_sentence_data():
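    """Load the sentence corpus (one row per sentence, with its ACL Anthology file id)."""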
    sentence_df = pd.read_csv("sentence_data_858k.csv.gz")
    
    return sentence_df


@st.cache(allow_output_mutation=True)
def load_sentence_embeddings_and_index():
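    """Load the precomputed sentence embeddings and build a faiss index over them.

    The embeddings are L2-normalized in place, so L2 distance ranks results the
    same way as cosine similarity. The index is an IVF-PQ index with an HNSW
    coarse quantizer, trained on the first 100,000 vectors.
    """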
    npz_comp = np.load("sentence_embeddings_858k.npz")
    sentence_embeddings = npz_comp["arr_0"]

    faiss.normalize_L2(sentence_embeddings)
    D = 768     # embedding dimensionality (BERT-base hidden size)
    N = 857610  # number of sentences in the corpus
    Xt = sentence_embeddings[:100000]  # training subset for the index
    X = sentence_embeddings

    # Param of PQ
    M = 16  # The number of sub-vectors. Typically 8, 16, 32, etc.
    nbits = 8  # Bits per sub-vector. Typically 8, so each sub-vector is encoded in 1 byte
    # Param of IVF
    nlist = int(math.sqrt(N))  # The number of cells (space partition). Typical value is sqrt(N)
    # Param of HNSW
    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32

    # Setup
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)

    # Train
    index.train(Xt)

    # Add
    index.add(X)

    # Search
    index.nprobe = 8  # Runtime param. The number of cells that are visited for search.

    return sentence_embeddings, index


@st.cache(allow_output_mutation=True)
def formulaic_phrase_extraction(sentences, model, tokenizer):
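    """Highlight formulaic phrases in each sentence.

    Tokens to which the [CLS] token attends with weight above `THRESHOLD`
    (averaged over heads at encoder layer `LAYER`) are wrapped in a coral
    <font> tag. Note: this assumes the merged WordPiece tokens line up with
    the whitespace-split words of the original sentence.
    """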
    THRESHOLD = 0.01  # minimum [CLS] attention weight for a token to be highlighted
    LAYER = 10        # encoder layer whose attention weights are used

    output_sentences = []

    with torch.no_grad():
        inputs = tokenizer.batch_encode_plus(
            sentences, 
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        attention = outputs[-1]  # per-layer attention tensors, each of shape (batch, heads, seq_len, seq_len)

        # Average over heads at the chosen layer and keep each sentence's [CLS] row
        cls_attentions = torch.mean(attention[LAYER], dim=1)[:, 0, :]

        for sentence, cls_attention in zip(sentences, cls_attentions):
            # Per-token highlight flags; [1:-1] drops the [CLS] position and the final special/padding position
            check_bool_arr = list((cls_attention > THRESHOLD).numpy())[1:-1]
            tokens = tokenizer.tokenize(sentence)

            cur_tokens = tokens.copy()

            # Merge WordPiece continuation pieces ("##...") back into whole words,
            # AND-ing their highlight flags; repeat until no "##" pieces remain
            while True:
                flg = False

                for idx, token in enumerate(cur_tokens):
                    if token.startswith("##"):
                        flg = True
                        back_token = token.replace("##", "")
                        front_token = cur_tokens.pop(idx-1)
                        cur_tokens[idx-1] = front_token + back_token

                        back_bool_val = check_bool_arr[idx]
                        front_bool_val = check_bool_arr.pop(idx-1)
                        check_bool_arr[idx-1] = front_bool_val and back_bool_val

                if not flg:
                    break

            # Wrap highlighted words in a coral <font> tag (assumes merged tokens align with whitespace-split words)
            result = " ".join([f'<font color="coral">{original_word}</font>' if b else original_word for (b, original_word) in zip(check_bool_arr, sentence.split())])
            output_sentences.append(result)

    return output_sentences


@st.cache(allow_output_mutation=True)
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=True):
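    """Retrieve up to `top_k` sentences that are functionally similar to `input_text`.

    The query is embedded with ScitoricsBERT's [CLS] vector, L2-normalized, and
    searched against the faiss index. Sentences containing any word in
    `exclude_word_list` are filtered out, so fewer than `top_k` results may be
    returned.
    """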
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)

    _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
    retrieved_sentences = []
    retrieved_paper_ids = []

    for idx in ids[0]:
        if idx == -1:  # faiss pads results with -1 when fewer than k neighbors are found
            continue

        cur_sentence = sentence_df.loc[idx, "sentence"]
        cur_link = f"https://aclanthology.org/{sentence_df.loc[idx, 'file_id']}"

        if len(exclude_word_list) == 0:
            retrieved_sentences.append(cur_sentence)
            retrieved_paper_ids.append(cur_link)

        else:
            # Escape each word so regex metacharacters in the input are treated literally
            exclude_word_list_regex = '|'.join(re.escape(w) for w in exclude_word_list)
            pat = re.compile(exclude_word_list_regex)

            if not pat.search(cur_sentence):
                retrieved_sentences.append(cur_sentence)
                retrieved_paper_ids.append(cur_link)

    if phrase_annotated:
        retrieved_sentences = formulaic_phrase_extraction(retrieved_sentences, model, tokenizer)

    return retrieved_sentences, retrieved_paper_ids


if __name__ == "__main__":
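    # Streamlit UI: load the model, corpus, and index once (cached), collect the
    # query and options, then render the retrieved sentences as a markdown table.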
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()


    st.markdown("## AI-based Paraphrasing for Academic Writing")

    input_text = st.text_area("text input", "Our model shows good results.", placeholder="Write something here...")
    top_k = st.number_input('top_k (upper bound)', min_value=1, value=30, step=1)
    input_words = st.text_input("exclude words (comma separated)", "good, result")

    agree = st.checkbox('Include phrase annotation')

    if st.button('search'):
        exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
        retrieved_sentences, retrieved_paper_ids = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=agree)

        result_table_markdown = "|  sentence  |  source link  |\n|:---|:---|\n"

        for (retrieved_sentence, retrieved_paper_id) in zip(retrieved_sentences, retrieved_paper_ids):
            result_table_markdown += f"| {retrieved_sentence} | {retrieved_paper_id} |\n"
        
        st.markdown(result_table_markdown, unsafe_allow_html=True)

    st.markdown("---\n#### How this works")

    st.markdown("This app uses ScitoricsBERT [(Sugimoto and Aizawa, 2022)](https://aclanthology.org/2022.sdp-1.7/), a functional sentence representation model, to retrieve sentences that are functionally similar to the input. It also highlights phrasal patterns associated with that function by leveraging self-attention patterns within ScitoricsBERT.")