"""RAG refine QA script with page-number citation support."""
import argparse
import os
import sys

# BUG FIX: RetrievalQA lives in langchain.chains (not langchain_community.chains),
# and "langchain_communit" was a typo'd module name — both imports raised at startup.
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate

from llm_loader import load_llama_model
from vector_store import get_embeddings, load_vector_store
def create_refine_prompts_with_pages(language="ko"):
    """Build the prompt pair for a refine-type QA chain.

    Args:
        language: "ko" selects the Korean prompts; any other value
            (e.g. "en") selects the English prompts.

    Returns:
        Tuple ``(question_prompt, refine_prompt)`` of PromptTemplates.
    """
    # BUG FIX: this previously tested `language == "de"`, a value no caller
    # ever passes (CLI choices are "ko"/"en", build_rag_chain defaults to
    # "ko"), so the Korean prompts were unreachable. Test "ko" instead and
    # make it the default, matching build_rag_chain.
    if language == "ko":
        question_prompt = PromptTemplate(
            input_variables=["context_str", "question"],
            template="""
๋ค์์ ๊ฒ์๋ ๋ฌธ์ ์กฐ๊ฐ๋ค์ ๋๋ค:
{context_str}
์ ๋ฌธ์๋ค์ ์ฐธ๊ณ ํ์ฌ ์ง๋ฌธ์ ๋ต๋ณํด์ฃผ์ธ์.
**์ค์ํ ๊ท์น:**
- ๋ต๋ณ ์ ์ฐธ๊ณ ํ ๋ฌธ์๊ฐ ์๋ค๋ฉด ํด๋น ์ ๋ณด๋ฅผ ์ธ์ฉํ์ธ์
- ๋ฌธ์์ ๋ช ์๋ ์ ๋ณด๋ง ์ฌ์ฉํ๊ณ , ์ถ์ธกํ์ง ๋ง์ธ์
- ํ์ด์ง ๋ฒํธ๋ ์ถ์ฒ๋ ์ ๋ฌธ์์์ ํ์ธ๋ ๊ฒ๋ง ์ธ๊ธํ์ธ์
- ํ์คํ์ง ์์ ์ ๋ณด๋ "๋ฌธ์์์ ํ์ธ๋์ง ์์"์ด๋ผ๊ณ ๋ช ์ํ์ธ์
์ง๋ฌธ: {question}
๋ต๋ณ:"""
        )
        refine_prompt = PromptTemplate(
            input_variables=["question", "existing_answer", "context_str"],
            template="""
๊ธฐ์กด ๋ต๋ณ:
{existing_answer}
์ถ๊ฐ ๋ฌธ์:
{context_str}
๊ธฐ์กด ๋ต๋ณ์ ์ ์ถ๊ฐ ๋ฌธ์๋ฅผ ๋ฐํ์ผ๋ก ๋ณด์ํ๊ฑฐ๋ ์์ ํด์ฃผ์ธ์.
**๊ท์น:**
- ์๋ก์ด ์ ๋ณด๊ฐ ๊ธฐ์กด ๋ต๋ณ๊ณผ ๋ค๋ฅด๋ค๋ฉด ์์ ํ์ธ์
- ์ถ๊ฐ ๋ฌธ์์ ๋ช ์๋ ์ ๋ณด๋ง ์ฌ์ฉํ์ธ์
- ํ๋์ ์๊ฒฐ๋ ๋ต๋ณ์ผ๋ก ์์ฑํ์ธ์
- ํ์คํ์ง ์์ ์ถ์ฒ๋ ํ์ด์ง๋ ์ธ๊ธํ์ง ๋ง์ธ์
์ง๋ฌธ: {question}
๋ต๋ณ:"""
        )
    else:
        question_prompt = PromptTemplate(
            input_variables=["context_str", "question"],
            template="""
Here are the retrieved document fragments:
{context_str}
Please answer the question based on the above documents.
**Important rules:**
- Only use information explicitly stated in the documents
- If citing sources, only mention what is clearly indicated in the documents above
- Do not guess or infer page numbers not shown in the context
- If unsure, state "not confirmed in the provided documents"
Question: {question}
Answer:"""
        )
        refine_prompt = PromptTemplate(
            input_variables=["question", "existing_answer", "context_str"],
            template="""
Existing answer:
{existing_answer}
Additional documents:
{context_str}
Refine the existing answer using the additional documents.
**Rules:**
- Only use information explicitly stated in the additional documents
- Create one coherent final answer
- Do not mention uncertain sources or page numbers
Question: {question}
Answer:"""
        )
    return question_prompt, refine_prompt
def build_rag_chain(llm, vectorstore, language="ko", k=7):
    """Assemble a refine-type RetrievalQA chain over the given vector store."""
    q_prompt, r_prompt = create_refine_prompts_with_pages(language)
    retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="refine",
        retriever=retriever,
        chain_type_kwargs={
            "question_prompt": q_prompt,
            "refine_prompt": r_prompt,
        },
        return_source_documents=True,
    )
def ask_question_with_pages(qa_chain, question):
    """Run one query through the chain and print the answer plus a per-file
    summary (pages, sections, chunk types) of the source documents used.

    Args:
        qa_chain: object exposing ``invoke({"query": ...})`` returning a dict
            with ``"result"`` (str) and ``"source_documents"`` (objects with
            a ``metadata`` dict).
        question: the user question string.

    Returns:
        The raw result dict from ``qa_chain.invoke``.
    """
    result = qa_chain.invoke({"query": question})

    # Keep only the text after the last "A:" marker, if the model echoed one.
    answer = result['result']
    final_answer = answer.split("A:")[-1].strip() if "A:" in answer else answer.strip()

    print(f"\n๐งพ ์ง๋ฌธ: {question}")
    print(f"\n๐ข ์ต์ข ๋ต๋ณ: {final_answer}")

    # Metadata debug helper (intentionally disabled).
    # debug_metadata_info(result["source_documents"])

    # Group the source documents per file name.
    print("\n๐ ์ฐธ๊ณ ๋ฌธ์ ์์ฝ:")
    source_info = {}
    for doc in result["source_documents"]:
        source = doc.metadata.get('source', 'N/A')
        page = doc.metadata.get('page', 'N/A')
        doc_type = doc.metadata.get('type', 'N/A')
        section = doc.metadata.get('section', None)
        total_pages = doc.metadata.get('total_pages', None)
        filename = doc.metadata.get('filename', 'N/A')
        # Fall back to the basename of the source path when no filename is set.
        if filename == 'N/A':
            filename = os.path.basename(source) if source != 'N/A' else 'N/A'
        if filename not in source_info:
            source_info[filename] = {
                'pages': set(),
                'sections': set(),
                'types': set(),
                'total_pages': total_pages
            }
        if page != 'N/A':
            # Some loaders store a section label in the page field.
            if isinstance(page, str) and page.startswith('์น์ '):
                source_info[filename]['sections'].add(page)
            else:
                source_info[filename]['pages'].add(page)
        if section is not None:
            source_info[filename]['sections'].add(f"์น์ {section}")
        source_info[filename]['types'].add(doc_type)

    total_chunks = len(result["source_documents"])
    print(f"์ด ์ฌ์ฉ๋ ์ฒญํฌ ์: {total_chunks}")
    for filename, info in source_info.items():
        # BUG FIX: previously printed a literal placeholder instead of the
        # file name.
        print(f"\n- {filename}")
        if info['total_pages']:
            print(f" ์ ์ฒด ํ์ด์ง ์: {info['total_pages']}")
        if info['pages']:
            # Sort with a str key so mixed int/str page values cannot raise,
            # and the output is deterministic (sets have no order).
            pages_list = sorted(info['pages'], key=str)
            print(f" ํ์ด์ง: {', '.join(map(str, pages_list))}")
        if info['sections']:
            sections_list = sorted(info['sections'])
            print(f" ์น์ : {', '.join(sections_list)}")
        if not info['pages'] and not info['sections']:
            print(f" ํ์ด์ง: ์ ๋ณด ์์")
        types_str = ', '.join(sorted(info['types']))
        print(f" ์ ํ: {types_str}")
    return result
# ๊ธฐ์กด ask_question ํจ์๋ ask_question_with_pages๋ก ๊ต์ฒด | |
def ask_question(qa_chain, question):
    """Backward-compatibility wrapper: delegates to ask_question_with_pages."""
    return ask_question_with_pages(qa_chain, question)
if __name__ == "__main__":
    # CLI entry point: build the RAG chain, then answer a single --query or
    # drop into an interactive loop.
    parser = argparse.ArgumentParser(description="RAG refine system (ํ์ด์ง ๋ฒํธ ์ง์)")
    parser.add_argument("--vector_store", type=str, default="vector_db", help="๋ฒกํฐ ์คํ ์ด ๊ฒฝ๋ก")
    parser.add_argument("--model", type=str, default="LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct", help="LLM ๋ชจ๋ธ ID")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="์ฌ์ฉํ ๋๋ฐ์ด์ค")
    parser.add_argument("--k", type=int, default=7, help="๊ฒ์ํ ๋ฌธ์ ์")
    parser.add_argument("--language", type=str, default="ko", choices=["ko", "en"], help="์ฌ์ฉํ ์ธ์ด")
    parser.add_argument("--query", type=str, help="์ง๋ฌธ (์์ผ๋ฉด ๋ํํ ๋ชจ๋ ์คํ)")
    args = parser.parse_args()

    embeddings = get_embeddings(device=args.device)
    vectorstore = load_vector_store(embeddings, load_path=args.vector_store)
    # NOTE(review): --model is parsed but never passed to load_llama_model();
    # the flag currently has no effect — confirm intent with llm_loader.
    llm = load_llama_model()
    qa_chain = build_rag_chain(llm, vectorstore, language=args.language, k=args.k)
    print("๐ข RAG ํ์ด์ง ๋ฒํธ ์ง์ ์์คํ ์ค๋น ์๋ฃ!")

    if args.query:
        # One-shot mode.
        ask_question_with_pages(qa_chain, args.query)
    else:
        # Interactive REPL mode.
        print("๐ฌ ๋ํํ ๋ชจ๋ ์์ (์ข ๋ฃํ๋ ค๋ฉด 'exit', 'quit', '์ข ๋ฃ' ์ ๋ ฅ)")
        while True:
            try:
                query = input("\n์ง๋ฌธ: ").strip()
                if query.lower() in ["exit", "quit", "์ข ๋ฃ"]:
                    break
                if query:  # skip empty input
                    ask_question_with_pages(qa_chain, query)
            except KeyboardInterrupt:
                # Ctrl+C exits cleanly instead of dumping a traceback.
                print("\n\nํ๋ก๊ทธ๋จ์ ์ข ๋ฃํฉ๋๋ค.")
                break
            except Exception as e:
                # Top-level boundary: report and keep the REPL alive.
                print(f"โ ์ค๋ฅ ๋ฐ์: {e}\n๋ค์ ์๋ํด์ฃผ์ธ์.")