kaisugi committed on
Commit
05f1914
·
1 Parent(s): e8c441c
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -65,7 +65,7 @@ def build_faiss_index(sentence_emeddings):
65
  return index
66
 
67
 
68
- @st.cache
69
  def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
70
  with torch.no_grad():
71
  inputs = tokenizer.encode_plus(
@@ -80,20 +80,23 @@ def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_d
80
  query_embeddings = query_embeddings.detach().cpu().numpy()
81
  query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)
82
 
83
- print(np.array([query_embeddings]))
 
 
 
 
84
 
85
- dists, ids = index.search(x=np.array([query_embeddings]), k=top_k)
86
- print(dists)
87
- print(ids)
88
 
89
 
90
- def main(model, tokenizer, sentence_df, sentence_embeddings, index):
91
  st.markdown("## AI-based Paraphrasing for Academic Writing")
92
 
93
  input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
94
  top_k = st.number_input('top_k', min_value=1, value=10, step=1)
95
 
96
- get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
 
97
 
98
 
99
  if __name__ == "__main__":
@@ -104,4 +107,4 @@ if __name__ == "__main__":
104
  faiss.normalize_L2(sentence_emeddings)
105
  index = build_faiss_index(sentence_emeddings)
106
 
107
- main(model, tokenizer, sentence_df, sentence_emeddings, index)
 
65
  return index
66
 
67
 
68
+ @st.cache(allow_output_mutation=True)
69
  def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
70
  with torch.no_grad():
71
  inputs = tokenizer.encode_plus(
 
80
  query_embeddings = query_embeddings.detach().cpu().numpy()
81
  query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)
82
 
83
+ _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
84
+ retrieved_sentences = []
85
+
86
+ for id in ids[0]:
87
+ retrieved_sentences.append(sentence_df.loc[id, "sentence"])
88
 
89
+ return pd.DataFrame({"sentences": retrieved_sentences})
 
 
90
 
91
 
92
+ def main(model, tokenizer, sentence_df, index):
93
  st.markdown("## AI-based Paraphrasing for Academic Writing")
94
 
95
  input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
96
  top_k = st.number_input('top_k', min_value=1, value=10, step=1)
97
 
98
+ df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
99
+ st.table(df)
100
 
101
 
102
  if __name__ == "__main__":
 
107
  faiss.normalize_L2(sentence_emeddings)
108
  index = build_faiss_index(sentence_emeddings)
109
 
110
+ main(model, tokenizer, sentence_df, index)