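"""Streamlit demo: retrieve academic sentences similar to an input sentence.

The query is embedded with ScitoricsBERT (kaisugi/scitoricsbert), and the
closest matches among ~789k pre-computed sentence embeddings are found with a
FAISS IVF-PQ index.
"""
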
import os

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

# Allow duplicate OpenMP runtimes (workaround for crashes when faiss and torch
# each load their own copy of the OpenMP library).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
    model = AutoModel.from_pretrained("kaisugi/scitoricsbert")
    model.eval()
    
    return model, tokenizer


@st.cache(allow_output_mutation=True)
def load_sentence_data():
    sentence_df = pd.read_csv("sentence_data_789k.csv.gz")
    
    return sentence_df


@st.cache(allow_output_mutation=True)
def load_sentence_embeddings():
    npz_comp = np.load("sentence_embeddings_789k.npz")
    sentence_embeddings = npz_comp["arr_0"]

    return sentence_embeddings


@st.cache(allow_output_mutation=True)
def build_faiss_index(sentence_embeddings):
    D = 768     # Embedding dimension (BERT hidden size)
    N = 789188  # Corpus size, for reference
    Xt = sentence_embeddings[:39000]  # Training subset (roughly 39 * nlist vectors)
    X = sentence_embeddings

    # Params of PQ
    M = 16     # Number of sub-vectors. Typically 8, 16, 32, etc.
    nbits = 8  # Bits per sub-vector. Typically 8, so each sub-vector is encoded in 1 byte
               # (with M=16, every vector is compressed to 16 bytes plus index overhead).
    # Param of IVF
    nlist = 1000  # Number of cells (space partitions). A typical value is around sqrt(N).
    # Param of HNSW
    hnsw_m = 32  # Number of neighbors for the HNSW coarse quantizer. Typically 32.

    # Setup: IVF-PQ index with an HNSW graph as the coarse quantizer
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)

    # Train the quantizers on the subset, then add the full corpus
    index.train(Xt)
    index.add(X)

    # Runtime search parameter: the number of cells visited per query
    index.nprobe = 8

    return index


@st.cache(allow_output_mutation=True)
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        # Use the [CLS] token embedding as the query representation
        query_embedding = outputs.last_hidden_state[:, 0, :][0]
        query_embedding = query_embedding.detach().cpu().numpy()
        # L2-normalize so that distances in the index correspond to cosine similarity
        query_embedding = query_embedding / np.linalg.norm(query_embedding, ord=2)

    _, ids = index.search(x=np.array([query_embedding]), k=top_k)

    retrieved_sentences = [sentence_df.loc[idx, "sentence"] for idx in ids[0]]

    return pd.DataFrame({"sentences": retrieved_sentences})


def main(model, tokenizer, sentence_df, index):
    st.markdown("## AI-based Paraphrasing for Academic Writing")

    input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
    top_k = st.number_input('top_k', min_value=1, value=10, step=1)

    df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
    st.table(df)


if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings = load_sentence_embeddings()

    # Normalize the corpus embeddings in place so that L2 search in the index
    # approximates cosine similarity (the query is normalized the same way).
    faiss.normalize_L2(sentence_embeddings)
    index = build_faiss_index(sentence_embeddings)

    main(model, tokenizer, sentence_df, index)
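
# To run locally (assuming this file is saved as app.py and the data files
# sentence_data_789k.csv.gz / sentence_embeddings_789k.npz are in the same
# directory): streamlit run app.py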