ryota39 committed on
Commit
6a6ca54
·
verified ·
1 Parent(s): c98fb1f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import streamlit as st
4
+ from langchain_community.vectorstores.faiss import FAISS
5
+ from langchain_huggingface import HuggingFaceEmbeddings
6
+
7
+
8
# NOTE(review): presumably works around the "OMP: Error #15" abort when two
# OpenMP runtimes (e.g. from FAISS and torch) are loaded in one process —
# confirm this is still needed in the deployment environment.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
9
+
10
+
11
@st.cache_resource
def create_vector_store(
    vector_store_path: str,
    embedding_model_name: str,
) -> FAISS:
    """Load a persisted FAISS index from disk.

    Embeddings are produced by the given HuggingFace model. Cached via
    Streamlit's ``cache_resource`` so the index and embedding model are
    constructed once per process instead of on every script rerun.
    """
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    # allow_dangerous_deserialization is required for pickle-backed FAISS
    # stores; acceptable here only because the index folder ships with the
    # app rather than coming from an untrusted upload.
    return FAISS.load_local(
        folder_path=vector_store_path,
        embeddings=embeddings,
        allow_dangerous_deserialization=True,
    )
23
+
24
+
25
def grab_topk(
    input_text: str,
    vector_store: FAISS,
    top_k: int,
) -> pd.DataFrame:
    """Retrieve the ``top_k`` documents most similar to ``input_text``.

    Fetches ``top_k + 1`` hits and drops the first one (presumably the
    query's own / near-duplicate entry — TODO confirm against how the
    index was built), then parses each document's ``page_content`` into
    title, abstract and URL.

    Returns a DataFrame with columns ``title``, ``abstract``, ``url``.
    """
    retriever = vector_store.as_retriever(search_kwargs={"k": top_k + 1})
    # .invoke() supersedes .get_relevant_documents(), which is deprecated
    # in langchain-core (since 0.1.46); both return the same list of docs.
    relevant_docs = retriever.invoke(input_text)

    titles: list[str] = []
    abstracts: list[str] = []
    urls: list[str] = []
    for doc in relevant_docs[1:]:
        content = doc.page_content
        # page_content layout inferred from this parsing (verify against
        # the index builder):
        #   <title>\\n<abstract><BEGIN_URL><url><END_URL>
        # NOTE(review): the delimiter is the two-character sequence
        # backslash + n ("\\n"), NOT a real newline — confirm this matches
        # how the documents were serialized.
        titles.append(content.split("\\n")[0])
        abstracts.append(content.split("\\n")[-1].split("<BEGIN_URL>")[0])
        urls.append(content.split("<BEGIN_URL>")[-1].split("<END_URL>")[0])
    return pd.DataFrame({"title": titles, "abstract": abstracts, "url": urls})
46
+
47
+
48
if __name__ == "__main__":
    # Load (or reuse the cached) FAISS index built with the bge-m3 embedder.
    vector_store = create_vector_store(
        vector_store_path="db",
        embedding_model_name="BAAI/bge-m3",
    )

    # Minimal search UI: query box, result-count selector, search button.
    st.markdown("## ICLR2025")
    input_text = st.text_input("query", "", placeholder="")
    top_k = st.number_input("top_k", min_value=1, value=10, step=1)

    # "検索" = "Search"
    if st.button("検索"):
        results = grab_topk(input_text.strip(), vector_store, top_k)
        st.table(results)