Spaces:
Runtime error
Runtime error
update
Browse files
app.py
CHANGED
@@ -65,7 +65,7 @@ def build_faiss_index(sentence_emeddings):
|
|
65 |
return index
|
66 |
|
67 |
|
68 |
-
@st.cache
|
69 |
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
|
70 |
with torch.no_grad():
|
71 |
inputs = tokenizer.encode_plus(
|
@@ -80,20 +80,23 @@ def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_d
|
|
80 |
query_embeddings = query_embeddings.detach().cpu().numpy()
|
81 |
query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)
|
82 |
|
83 |
-
|
|
|
|
|
|
|
|
|
84 |
|
85 |
-
|
86 |
-
print(dists)
|
87 |
-
print(ids)
|
88 |
|
89 |
|
90 |
-
def main(model, tokenizer, sentence_df,
|
91 |
st.markdown("## AI-based Paraphrasing for Academic Writing")
|
92 |
|
93 |
input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
|
94 |
top_k = st.number_input('top_k', min_value=1, value=10, step=1)
|
95 |
|
96 |
-
get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
|
|
|
97 |
|
98 |
|
99 |
if __name__ == "__main__":
|
@@ -104,4 +107,4 @@ if __name__ == "__main__":
|
|
104 |
faiss.normalize_L2(sentence_emeddings)
|
105 |
index = build_faiss_index(sentence_emeddings)
|
106 |
|
107 |
-
main(model, tokenizer, sentence_df,
|
|
|
65 |
return index
|
66 |
|
67 |
|
68 |
+
@st.cache(allow_output_mutation=True)
|
69 |
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
|
70 |
with torch.no_grad():
|
71 |
inputs = tokenizer.encode_plus(
|
|
|
80 |
query_embeddings = query_embeddings.detach().cpu().numpy()
|
81 |
query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)
|
82 |
|
83 |
+
_, ids = index.search(x=np.array([query_embeddings]), k=top_k)
|
84 |
+
retrieved_sentences = []
|
85 |
+
|
86 |
+
for id in ids[0]:
|
87 |
+
retrieved_sentences.append(sentence_df.loc[id, "sentence"])
|
88 |
|
89 |
+
return pd.DataFrame({"sentences": retrieved_sentences})
|
|
|
|
|
90 |
|
91 |
|
92 |
+
def main(model, tokenizer, sentence_df, index):
|
93 |
st.markdown("## AI-based Paraphrasing for Academic Writing")
|
94 |
|
95 |
input_text = st.text_area("text input", "Model have good results.", placeholder="Write something here...")
|
96 |
top_k = st.number_input('top_k', min_value=1, value=10, step=1)
|
97 |
|
98 |
+
df = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df)
|
99 |
+
st.table(df)
|
100 |
|
101 |
|
102 |
if __name__ == "__main__":
|
|
|
107 |
faiss.normalize_L2(sentence_emeddings)
|
108 |
index = build_faiss_index(sentence_emeddings)
|
109 |
|
110 |
+
main(model, tokenizer, sentence_df, index)
|