ryota39's picture
Upload app.py
6a6ca54 verified
raw
history blame
1.9 kB
import os
import pandas as pd
import streamlit as st
from langchain_community.vectorstores.faiss import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
# Workaround for the "duplicate OpenMP runtime" crash that can occur when
# FAISS and the embedding model's backend each bundle their own libomp.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
@st.cache_resource
def create_vector_store(
    vector_store_path: str,
    embedding_model_name: str,
) -> FAISS:
    """Load a persisted FAISS index from disk, cached across Streamlit reruns.

    Args:
        vector_store_path: Folder containing the saved FAISS index files.
        embedding_model_name: HuggingFace model id used to embed queries.

    Returns:
        The deserialized FAISS vector store.
    """
    embedder = HuggingFaceEmbeddings(model_name=embedding_model_name)
    # NOTE(review): allow_dangerous_deserialization executes pickled data on
    # load — acceptable only because the index under `vector_store_path` is
    # shipped with this app, not user-supplied.
    return FAISS.load_local(
        folder_path=vector_store_path,
        embeddings=embedder,
        allow_dangerous_deserialization=True,
    )
def grab_topk(
    input_text: str,
    vector_store: "FAISS",
    top_k: int,
) -> pd.DataFrame:
    """Retrieve the ``top_k`` most relevant documents for a query.

    Args:
        input_text: The user's free-text search query.
        vector_store: FAISS store whose documents were serialized as
            ``<title>\\n<abstract><BEGIN_URL><url><END_URL>`` (the separator
            is the literal two-character backslash-n sequence, not a newline).
        top_k: Number of results to return.

    Returns:
        DataFrame with ``title``, ``abstract`` and ``url`` columns,
        one row per retrieved document.
    """
    retriever = vector_store.as_retriever(search_kwargs={"k": top_k + 1})
    # NOTE(review): get_relevant_documents() is deprecated in newer LangChain
    # in favor of retriever.invoke(); kept for compatibility with the version
    # this app was pinned against.
    relevant_docs = retriever.get_relevant_documents(input_text)

    titles: list[str] = []
    abstracts: list[str] = []
    urls: list[str] = []
    # k+1 results are requested and the first is discarded — presumably the
    # top hit is a duplicate/self match in this corpus; TODO confirm.
    for doc in relevant_docs[1:]:
        content = doc.page_content
        # Split tokens below use the literal backslash-n separator described
        # in the docstring; do not "fix" it to a real newline.
        titles.append(content.split("\\n")[0])
        abstracts.append(content.split("\\n")[-1].split("<BEGIN_URL>")[0])
        urls.append(content.split("<BEGIN_URL>")[-1].split("<END_URL>")[0])
    return pd.DataFrame({"title": titles, "abstract": abstracts, "url": urls})
if __name__ == "__main__":
    # App configuration: local index folder and query-embedding model.
    VECTOR_STORE_PATH = "db"
    EMBEDDING_MODEL_NAME = "BAAI/bge-m3"

    store = create_vector_store(VECTOR_STORE_PATH, EMBEDDING_MODEL_NAME)

    st.markdown("## ICLR2025")
    query = st.text_input("query", "", placeholder="")
    k = st.number_input("top_k", min_value=1, value=10, step=1)

    # Streamlit reruns the whole script on every widget interaction, so the
    # search only executes on an explicit button press.
    if st.button("検索"):
        results = grab_topk(query.strip(), store, k)
        st.table(results)