Spaces:
Sleeping
Sleeping
import os | |
import pickle | |
import streamlit as st | |
from dotenv import load_dotenv | |
from laas import ChatLaaS | |
from langchain.embeddings import CacheBackedEmbeddings | |
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever | |
from langchain.retrievers.document_compressors import ( | |
CrossEncoderReranker, | |
FlashrankRerank, | |
) | |
from langchain_core.vectorstores import VectorStore | |
from langchain.storage import LocalFileStore | |
from langchain_community.cross_encoders import HuggingFaceCrossEncoder | |
from langchain_community.document_loaders.generic import GenericLoader | |
from langchain_community.document_loaders.parsers.language.language_parser import ( | |
LanguageParser, | |
) | |
from langchain_community.retrievers import BM25Retriever | |
from langchain_community.vectorstores import FAISS | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnableLambda, RunnablePassthrough | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter | |
# Read variables from a local .env file into the process environment.
load_dotenv()

# LangChain tracing settings: enable tracing and name the project.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "Code QA Bot",
    }
)
def _detect_torch_device() -> str:
    """Return the first available torch device name: "mps", "cuda" or "cpu"."""
    # Local import keeps the module's top-level imports unchanged.
    import torch

    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def setup_embeddings_and_db(project_folder: str, device: "str | None" = None):
    """Load the project's FAISS index using cache-backed BAAI/bge-m3 embeddings.

    Model files are cached under ``~/.cache/models`` and computed embeddings
    under ``~/.cache/embeddings``; both directories are created if missing.

    Args:
        project_folder: Directory containing the ``langchain_faiss`` index
            folder.
        device: Torch device name passed to the embedding model (for example
            ``"mps"``, ``"cuda"`` or ``"cpu"``). When omitted, MPS is used if
            available, then CUDA, then CPU.

    Returns:
        The FAISS vector store loaded from ``<project_folder>/langchain_faiss``.
    """
    if device is None:
        device = _detect_torch_device()

    cache_root = os.path.join(os.path.expanduser("~"), ".cache")
    models_dir = os.path.join(cache_root, "models")
    embeddings_dir = os.path.join(cache_root, "embeddings")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(embeddings_dir, exist_ok=True)

    store = LocalFileStore(embeddings_dir)

    embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": False},
        cache_folder=models_dir,
        multi_process=False,
        show_progress=True,
    )
    # Namespacing the cache by model name keeps vectors from different
    # embedding models apart in the shared on-disk store.
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace=embeddings.model_name,
    )

    faiss_index_dir = os.path.join(project_folder, "langchain_faiss")
    # NOTE(security): allow_dangerous_deserialization=True lets load_local
    # unpickle data stored with the index. Only point project_folder at
    # index directories you trust.
    db = FAISS.load_local(
        faiss_index_dir,  # directory of the FAISS index to load
        cached_embeddings,  # embedding function used for queries
        allow_dangerous_deserialization=True,  # permit pickle deserialization
    )
    return db
# Function to set up retrievers and chain | |
def setup_retrievers_and_chain(
    _db: VectorStore, project_folder: str
):
    """Build the retrieval-augmented QA chain for one project.

    The chain retrieves candidates with a BM25 + FAISS ensemble, reranks them
    with a cross-encoder, and sends the top results plus the question to the
    LaaS chat preset.

    Args:
        _db: Vector store used for the dense (MMR) retriever.
        project_folder: Directory containing ``bm25_retriever.pkl``.

    Returns:
        A runnable that takes the question string and returns the answer text.
    """
    faiss_retriever = _db.as_retriever(search_type="mmr", search_kwargs={"k": 20})

    # NOTE(security): pickle.load executes arbitrary code embedded in the
    # file. project_folder comes from user input, so only use folders whose
    # contents you trust.
    bm25_retriever_path = os.path.join(project_folder, "bm25_retriever.pkl")
    with open(bm25_retriever_path, "rb") as f:
        bm25_retriever = pickle.load(f)
    bm25_retriever.k = 20

    # EnsembleRetriever has no ``search_type`` option; MMR is already
    # configured on the FAISS retriever above, so the stray argument the
    # original passed here has been dropped.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.6, 0.4],
    )

    # Rerank the merged candidates and keep only the five best.
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    compressor = CrossEncoderReranker(model=model, top_n=5)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    )

    laas = ChatLaaS(
        project=st.secrets["LAAS_PROJECT"],
        api_key=st.secrets["LAAS_API_KEY"],
        hash=st.secrets["LAAS_HASH"],
    )

    rag_chain = (
        {
            # The retrieved documents are passed to the preset as one string.
            "context": compression_retriever | RunnableLambda(lambda x: str(x)),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(
            # The message text is empty; context and question are supplied
            # through ``params`` instead.
            lambda x: laas.invoke(
                "", params={"context": x["context"], "question": x["question"]}
            )
        )
        | StrOutputParser()
    )
    return rag_chain
def _reset_conversation() -> None:
    """Button callback: clear the question box and the stored answer."""
    # Widget state may only be changed from a callback, which runs before the
    # next script rerun. Clearing the question keeps the answer from being
    # regenerated immediately after the reset.
    st.session_state.user_question = ""
    st.session_state.answer = ""
    st.session_state.answered_for = None


def main():
    """Render the Code QA Bot page and answer questions about a project.

    The retrieval chain is built once per project folder and kept in the
    session state, and a question is sent to the chain only when the
    (folder, question) pair differs from the one already answered.
    """
    st.title("Code QA Bot")

    # Initialize session state for the project folder, the answer, and the
    # cached chain.
    defaults = {
        "project_folder": "",
        "answer": "",
        "answered_for": None,
        "rag_chain": None,
        "rag_chain_folder": None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    # Ask for the project path.
    project_folder = st.text_input(
        "Enter the project folder path:", value=st.session_state.project_folder
    )
    st.session_state.project_folder = project_folder

    if not project_folder:
        st.warning("Please enter the project folder path to proceed.")
        return

    # Build the vector store and chain only when the folder changes; doing it
    # on every rerun would reload the models on each widget interaction.
    if (
        st.session_state.rag_chain is None
        or st.session_state.rag_chain_folder != project_folder
    ):
        db = setup_embeddings_and_db(project_folder)
        st.session_state.rag_chain = setup_retrievers_and_chain(db, project_folder)
        st.session_state.rag_chain_folder = project_folder
    rag_chain = st.session_state.rag_chain

    # Ask for the user's question.
    user_question = st.text_input(
        "Ask a question about the code:", key="user_question"
    )

    # Button that clears the current question and answer.
    st.button("Reset Answer", on_click=_reset_conversation)

    request = (project_folder, user_question)
    if user_question and st.session_state.answered_for != request:
        with st.spinner("Generating answer..."):
            st.session_state.answer = rag_chain.invoke(user_question)
        st.session_state.answered_for = request

    # Display the answer.
    if st.session_state.answer:
        st.write(st.session_state.answer)


if __name__ == "__main__":
    main()