kaisugi committed
Commit e8c441c · 1 Parent(s): 566dd98
Files changed (1)
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
import os

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

# Avoid "OMP: Error #15" aborts when faiss and torch load duplicate OpenMP runtimes.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
    model = AutoModel.from_pretrained("kaisugi/scitoricsbert")
    model.eval()

    return model, tokenizer


@st.cache(allow_output_mutation=True)
def load_sentence_data():
    sentence_df = pd.read_csv("sentence_data_789k.csv.gz")

    return sentence_df


@st.cache(allow_output_mutation=True)
def load_sentence_embeddings():
    npz_comp = np.load("sentence_embeddings_789k.npz")
    sentence_embeddings = npz_comp["arr_0"]

    return sentence_embeddings
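
(Aside, not part of this commit: a minimal sketch of how a file like sentence_embeddings_789k.npz could be regenerated with the same model, assuming the CSV holds the raw sentences in a column named "sentence" — that column name and the batch size are assumptions.)

import numpy as np
import pandas as pd
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
model = AutoModel.from_pretrained("kaisugi/scitoricsbert")
model.eval()

sentences = pd.read_csv("sentence_data_789k.csv.gz")["sentence"].tolist()  # assumed column name
chunks = []
with torch.no_grad():
    for i in range(0, len(sentences), 64):  # assumed batch size
        batch = tokenizer(sentences[i:i + 64], padding=True, truncation=True,
                          max_length=512, return_tensors="pt")
        out = model(**batch)
        chunks.append(out.last_hidden_state[:, 0, :].cpu().numpy())  # [CLS] embeddings
np.savez_compressed("sentence_embeddings_789k.npz", np.concatenate(chunks))  # stored as "arr_0"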


@st.cache(allow_output_mutation=True)
def build_faiss_index(sentence_embeddings):
    D = 768     # embedding dimension of the model
    N = 789188  # number of database sentences
    Xt = sentence_embeddings[:39000]  # training subset for the quantizers
    X = sentence_embeddings

    # Params of PQ
    M = 16     # number of sub-vectors; typically 8, 16, 32, etc.
    nbits = 8  # bits per sub-vector; typically 8, so each sub-vector is encoded in 1 byte
    # Param of IVF
    nlist = 1000  # number of cells (space partitions); a typical value is sqrt(N)
    # Param of HNSW
    hnsw_m = 32  # number of neighbors in the HNSW graph; typically 32

    # Setup: an IVF-PQ index whose coarse quantizer is an HNSW graph
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)

    # Train the coarse quantizer and PQ codebooks, then add all vectors
    index.train(Xt)
    index.add(X)

    # Runtime param: the number of cells visited per search
    index.nprobe = 8

    return index
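
(Aside, not part of this commit: a rough way to sanity-check how nprobe trades accuracy for speed, comparing the approximate index against exact search on a small query sample; the sample size and nprobe grid are arbitrary, and sentence_embeddings/index are assumed to be the normalized array and the index built above.)

flat = faiss.IndexFlatL2(768)   # exact baseline over the same vectors
flat.add(sentence_embeddings)

queries = sentence_embeddings[:100]  # arbitrary query sample
_, gt = flat.search(queries, 10)     # ground-truth top-10 ids

for nprobe in (1, 8, 32, 128):
    index.nprobe = nprobe
    _, approx = index.search(queries, 10)
    recall = np.mean([len(set(a) & set(g)) / 10 for a, g in zip(approx, gt)])
    print(f"nprobe={nprobe}: recall@10={recall:.3f}")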


def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
    # Not cached: st.cache cannot hash the model, tokenizer, and faiss index
    # arguments with its default hasher, so a plain @st.cache here would fail
    # at runtime.
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        outputs = model(**inputs)
        # Use the [CLS] token embedding as the sentence representation
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()
        # L2-normalize the query to match the normalized database vectors
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)

    _dists, ids = index.search(np.array([query_embeddings]), k=int(top_k))

    return sentence_df.iloc[ids[0]]
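
(Aside: normalizing both the query and the database vectors makes the L2 index rank by cosine similarity, since for unit vectors ||u - v||^2 = 2 - 2·cos(u, v). A quick numeric check:)

u = np.random.randn(768); u /= np.linalg.norm(u)
v = np.random.randn(768); v /= np.linalg.norm(v)
assert np.isclose(np.sum((u - v) ** 2), 2 - 2 * np.dot(u, v))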


def main(model, tokenizer, sentence_df, sentence_embeddings, index):
    st.markdown("## AI-based Paraphrasing for Academic Writing")

    input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
    top_k = st.number_input("top_k", min_value=1, value=10, step=1)

    results = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
    st.table(results)


if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings = load_sentence_embeddings()

    # Normalize in place so L2 search over the index ranks by cosine similarity
    faiss.normalize_L2(sentence_embeddings)
    index = build_faiss_index(sentence_embeddings)

    main(model, tokenizer, sentence_df, sentence_embeddings, index)
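
With streamlit, torch, transformers, faiss-cpu, pandas, and numpy installed, and sentence_data_789k.csv.gz plus sentence_embeddings_789k.npz in the working directory, the app starts with:

streamlit run app.py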