File size: 5,428 Bytes
e7055d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
import pickle

import streamlit as st
from dotenv import load_dotenv
from laas import ChatLaaS
from langchain.embeddings import CacheBackedEmbeddings
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import (
    CrossEncoderReranker,
    FlashrankRerank,
)
from langchain_core.vectorstores import VectorStore
from langchain.storage import LocalFileStore
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
    LanguageParser,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter

# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Turn on LangChain (LangSmith) tracing for this app by default. setdefault,
# rather than plain assignment, keeps any value already supplied by the shell
# or by the .env file loaded above, so tracing can be disabled or redirected
# to another project without editing this file.
os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")
os.environ.setdefault("LANGCHAIN_PROJECT", "Code QA Bot")


@st.cache_resource
def setup_embeddings_and_db(project_folder: str, device: str = "mps"):
    """Load a project's FAISS index backed by cached BAAI/bge-m3 embeddings.

    Streamlit's ``cache_resource`` reuses the returned store across reruns
    for the same arguments, so the embedding model and the index are only
    loaded once per (project_folder, device).

    Args:
        project_folder: Directory containing the ``langchain_faiss`` index
            directory to load.
        device: Torch device passed to the embedding model. Defaults to
            ``"mps"`` (Apple Silicon), which was previously hard-coded; pass
            ``"cpu"`` or ``"cuda"`` on other hardware.

    Returns:
        The loaded ``FAISS`` vector store.
    """
    cache_root = os.path.join(os.path.expanduser("~"), ".cache")
    models_dir = os.path.join(cache_root, "models")
    embeddings_dir = os.path.join(cache_root, "embeddings")

    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(embeddings_dir, exist_ok=True)

    embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs={"device": device},
        # NOTE(review): presumably must match the settings used when the index
        # was built, otherwise query vectors won't be comparable — confirm.
        encode_kwargs={"normalize_embeddings": False},
        cache_folder=models_dir,  # downloaded model weights live here
        multi_process=False,
        show_progress=True,
    )

    # Wrap the model so computed embeddings are persisted in embeddings_dir,
    # namespaced by model name so different models never share entries.
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        LocalFileStore(embeddings_dir),
        namespace=embeddings.model_name,
    )

    index_dir = os.path.join(project_folder, "langchain_faiss")
    # SECURITY: allow_dangerous_deserialization lets load_local unpickle data
    # stored with the index. project_folder comes from a text box in the UI,
    # so only point this at indexes you built yourself — a crafted pickle can
    # execute arbitrary code.
    return FAISS.load_local(
        index_dir,  # directory name of the FAISS index to load
        cached_embeddings,  # embeddings used to encode queries for the index
        allow_dangerous_deserialization=True,  # permit pickle deserialization
    )


def _format_docs(docs: list) -> str:
    """Join retrieved documents into one plain-text context string.

    Each chunk's text is kept verbatim (real newlines, no repr quoting) and
    prefixed with its ``source`` metadata when present; chunks are separated
    by blank lines. This replaces ``str(docs)``, which produced Python reprs
    such as ``Document(metadata=..., page_content='...\\n...')``.
    """
    parts = []
    for doc in docs:
        source = doc.metadata.get("source")
        if source:
            parts.append(f"Source: {source}\n{doc.page_content}")
        else:
            parts.append(doc.page_content)
    return "\n\n".join(parts)


# Function to set up retrievers and chain
@st.cache_resource
def setup_retrievers_and_chain(
    _db: VectorStore, project_folder: str
):
    """Build the hybrid-retrieval RAG chain for one project.

    Streamlit does not hash parameters whose names start with an underscore,
    so the (unhashable) ``_db`` is left out of the cache key and the chain is
    cached per ``project_folder``.

    Pipeline: BM25 + FAISS (MMR) candidates, 20 each -> weighted ensemble ->
    cross-encoder rerank to the top 5 -> LaaS chat model -> plain string.

    Args:
        _db: Vector store to use for dense retrieval.
        project_folder: Directory containing ``bm25_retriever.pkl``.

    Returns:
        A runnable that takes a question string and returns the answer string.
    """
    # Dense retriever: MMR search over the vector store.
    faiss_retriever = _db.as_retriever(search_type="mmr", search_kwargs={"k": 20})

    # Sparse retriever: a pre-built BM25 retriever pickled into the project.
    # SECURITY: pickle.load can execute arbitrary code from a crafted file, and
    # project_folder comes from user input — only load files you generated.
    bm25_retriever_path = os.path.join(project_folder, "bm25_retriever.pkl")
    with open(bm25_retriever_path, "rb") as f:
        bm25_retriever = pickle.load(f)
    bm25_retriever.k = 20

    # Fuse both candidate lists, weighting BM25 0.6 and FAISS 0.4. The former
    # search_type="mmr" argument was dropped: it is not an EnsembleRetriever
    # field (MMR is already configured on faiss_retriever above).
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.6, 0.4],
    )

    # Rerank the fused candidates with a cross-encoder and keep the best 5.
    reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=CrossEncoderReranker(model=reranker, top_n=5),
        base_retriever=ensemble_retriever,
    )

    laas = ChatLaaS(
        project=st.secrets["LAAS_PROJECT"],
        api_key=st.secrets["LAAS_API_KEY"],
        hash=st.secrets["LAAS_HASH"],
    )

    def _ask_laas(inputs: dict):
        # NOTE(review): an empty message plus template params suggests the
        # prompt itself lives in the LaaS preset identified by LAAS_HASH —
        # confirm against the LaaS project.
        return laas.invoke(
            "", params={"context": inputs["context"], "question": inputs["question"]}
        )

    rag_chain = (
        {
            "context": compression_retriever | RunnableLambda(_format_docs),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(_ask_laas)
        | StrOutputParser()
    )

    return rag_chain


def main():
    """Render the Code QA Bot page.

    Asks for a project folder, loads that project's retrieval chain (cached by
    Streamlit), and answers questions about the code. The latest answer is
    kept in ``st.session_state`` so it survives Streamlit reruns.
    """
    st.title("Code QA Bot")

    # Streamlit reruns this whole script on every interaction; session_state
    # carries values across those reruns.
    if "project_folder" not in st.session_state:
        st.session_state.project_folder = ""
    if "answer" not in st.session_state:
        st.session_state.answer = ""
    # The question that produced the current answer. The chain is only invoked
    # when the question differs from this, so unrelated reruns (e.g. clicking
    # "Reset Answer") no longer trigger another slow LLM call.
    if "answered_question" not in st.session_state:
        st.session_state.answered_question = ""

    # Ask for the project path.
    project_folder = st.text_input(
        "Enter the project folder path:", value=st.session_state.project_folder
    )
    if project_folder != st.session_state.project_folder:
        # An answer belongs to one project; drop it when the project changes.
        st.session_state.answer = ""
        st.session_state.answered_question = ""
    st.session_state.project_folder = project_folder

    if not project_folder:
        st.warning("Please enter the project folder path to proceed.")
        return
    if not os.path.isdir(project_folder):
        # Show a readable error instead of a traceback from the index loader.
        st.error(f"Project folder not found: {project_folder}")
        return

    # With a project path entered, set up the vector store and the chain.
    db = setup_embeddings_and_db(project_folder)
    rag_chain = setup_retrievers_and_chain(db, project_folder)

    # Ask for the user's question.
    user_question = st.text_input("Ask a question about the code:")

    # Add a button to reset the answer
    if st.button("Reset Answer"):
        st.session_state.answer = ""
        # Mark the current question as handled so the block below does not
        # immediately regenerate the answer that was just cleared.
        st.session_state.answered_question = user_question

    if user_question and user_question != st.session_state.answered_question:
        with st.spinner("Generating answer..."):
            st.session_state.answer = rag_chain.invoke(user_question)
        # Recorded only after a successful invoke, so a failure is retried.
        st.session_state.answered_question = user_question

    # Display the most recent answer, if any.
    if st.session_state.answer:
        st.write(st.session_state.answer)


if __name__ == "__main__":
    main()