File size: 1,899 Bytes
6a6ca54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import pandas as pd
import streamlit as st
from langchain_community.vectorstores.faiss import FAISS
from langchain_huggingface import HuggingFaceEmbeddings


# NOTE(review): KMP_DUPLICATE_LIB_OK=True is presumably here to stop the
# Intel OpenMP runtime from aborting when two copies of it get loaded into
# this process (e.g. by faiss and torch). It is set at import time, before
# create_vector_store runs — confirm it is still required.
os.environ.update(KMP_DUPLICATE_LIB_OK="True")


@st.cache_resource
def create_vector_store(
    vector_store_path: str,
    embedding_model_name: str,
) -> FAISS:
    """Load a FAISS vector store from disk; the result is cached by ``st.cache_resource``.

    Args:
        vector_store_path: Directory handed to ``FAISS.load_local`` as
            ``folder_path``.
        embedding_model_name: Model name for ``HuggingFaceEmbeddings``;
            presumably must match the model the index was built with — confirm.

    Returns:
        The loaded ``FAISS`` store.
    """
    # SECURITY: allow_dangerous_deserialization=True opts into the
    # deserialization LangChain flags as unsafe (pickle-based). Only load an
    # index directory you built yourself and trust.
    return FAISS.load_local(
        folder_path=vector_store_path,
        embeddings=HuggingFaceEmbeddings(model_name=embedding_model_name),
        allow_dangerous_deserialization=True,
    )


def _parse_content(content: str) -> tuple[str, str, str]:
    r"""Split one stored document's text into ``(title, abstract, url)``.

    The layout is inferred from the parsing below:
    ``<title>\n...\n<abstract><BEGIN_URL><url><END_URL>``, where ``\n`` is
    the literal two-character sequence backslash + ``n``, NOT a real newline.
    NOTE(review): confirm against how the vector store was built; if the
    documents contain real newlines instead, the title will be the whole text.
    """
    segments = content.split("\\n")
    title = segments[0]
    abstract = segments[-1].split("<BEGIN_URL>")[0]
    url = content.split("<BEGIN_URL>")[-1].split("<END_URL>")[0]
    return title, abstract, url


def grab_topk(
    input_text: str,
    vector_store: FAISS,
    top_k: int,
    *,
    skip_first: bool = True,
) -> pd.DataFrame:
    """Return the ``top_k`` documents retrieved for ``input_text``.

    Args:
        input_text: Query text passed to the retriever.
        vector_store: FAISS store whose documents follow the layout parsed by
            ``_parse_content``.
        top_k: Maximum number of rows to return (fewer if the retriever
            returns fewer documents).
        skip_first: Drop the single best match. Defaults to ``True`` to keep
            the original behaviour, which presumably assumes the query is
            itself a paper in the store and would match itself first —
            TODO confirm. Pass ``False`` for free-text queries.

    Returns:
        DataFrame with columns ``title``, ``abstract``, ``url`` in retrieval
        order; empty (but with those columns) when nothing is retrieved.
    """
    offset = 1 if skip_first else 0
    # Over-fetch by ``offset`` so ``top_k`` rows remain after skipping.
    retriever = vector_store.as_retriever(search_kwargs={"k": top_k + offset})
    # ``invoke`` replaces ``get_relevant_documents``, which is deprecated
    # since langchain-core 0.1.46 and slated for removal in 1.0.
    relevant_docs = retriever.invoke(input_text)
    rows = [_parse_content(doc.page_content) for doc in relevant_docs[offset:]]
    # Explicit columns keep the schema stable even when ``rows`` is empty.
    return pd.DataFrame(rows, columns=["title", "abstract", "url"])


if __name__ == "__main__":
    # Streamlit UI: load the store (cached via create_vector_store's
    # @st.cache_resource), then search when the button is pressed.
    vector_store_path = "db"
    embedding_model_name = "BAAI/bge-m3"
    vector_store = create_vector_store(
        vector_store_path,
        embedding_model_name,
    )

    st.markdown("## ICLR2025")
    input_text = st.text_input("query", "", placeholder="")
    top_k = st.number_input("top_k", min_value=1, value=10, step=1)

    if st.button("検索"):  # button label means "Search"
        stripped_input_text = input_text.strip()
        if not stripped_input_text:
            # Don't run a similarity search for an empty/whitespace-only query.
            st.warning("Please enter a query.")
        else:
            df = grab_topk(
                stripped_input_text,
                vector_store,
                top_k,
            )
            st.table(df)