# open-webui-rag-system / rag_system.py
import os
import argparse
import sys
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from vector_store import get_embeddings, load_vector_store
from llm_loader import load_llama_model
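# vector_store and llm_loader are local helper modules from this repo; they are
# assumed to expose get_embeddings(device=...), load_vector_store(embeddings,
# load_path=...), and load_llama_model() exactly as used in __main__ below.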
def create_refine_prompts_with_pages(language="ko"):
    """Build the (question_prompt, refine_prompt) pair for a refine-type QA chain."""
    # NOTE: the original code tested for "de" here, which the CLI (choices:
    # "ko", "en") could never pass in; the Korean templates are keyed to "ko".
    if language == "ko":
        # Initial-pass prompt (Korean): answer strictly from the retrieved
        # context and cite only sources/pages confirmed in the documents.
        question_prompt = PromptTemplate(
            input_variables=["context_str", "question"],
            template="""
๋‹ค์Œ์€ ๊ฒ€์ƒ‰๋œ ๋ฌธ์„œ ์กฐ๊ฐ๋“ค์ž…๋‹ˆ๋‹ค:
{context_str}

์œ„ ๋ฌธ์„œ๋“ค์„ ์ฐธ๊ณ ํ•˜์—ฌ ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”.

**์ค‘์š”ํ•œ ๊ทœ์น™:**
- ๋‹ต๋ณ€ ์‹œ ์ฐธ๊ณ ํ•œ ๋ฌธ์„œ๊ฐ€ ์žˆ๋‹ค๋ฉด ํ•ด๋‹น ์ •๋ณด๋ฅผ ์ธ์šฉํ•˜์„ธ์š”
- ๋ฌธ์„œ์— ๋ช…์‹œ๋œ ์ •๋ณด๋งŒ ์‚ฌ์šฉํ•˜๊ณ , ์ถ”์ธกํ•˜์ง€ ๋งˆ์„ธ์š”
- ํŽ˜์ด์ง€ ๋ฒˆํ˜ธ๋‚˜ ์ถœ์ฒ˜๋Š” ์œ„ ๋ฌธ์„œ์—์„œ ํ™•์ธ๋œ ๊ฒƒ๋งŒ ์–ธ๊ธ‰ํ•˜์„ธ์š”
- ํ™•์‹คํ•˜์ง€ ์•Š์€ ์ •๋ณด๋Š” "๋ฌธ์„œ์—์„œ ํ™•์ธ๋˜์ง€ ์•Š์Œ"์ด๋ผ๊ณ  ๋ช…์‹œํ•˜์„ธ์š”

์งˆ๋ฌธ: {question}
๋‹ต๋ณ€:"""
        )
        # Refine prompt (Korean): revise or extend the existing answer using
        # each additional retrieved chunk.
        refine_prompt = PromptTemplate(
            input_variables=["question", "existing_answer", "context_str"],
            template="""
๊ธฐ์กด ๋‹ต๋ณ€:
{existing_answer}

์ถ”๊ฐ€ ๋ฌธ์„œ:
{context_str}

๊ธฐ์กด ๋‹ต๋ณ€์„ ์œ„ ์ถ”๊ฐ€ ๋ฌธ์„œ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋ณด์™„ํ•˜๊ฑฐ๋‚˜ ์ˆ˜์ •ํ•ด์ฃผ์„ธ์š”.

**๊ทœ์น™:**
- ์ƒˆ๋กœ์šด ์ •๋ณด๊ฐ€ ๊ธฐ์กด ๋‹ต๋ณ€๊ณผ ๋‹ค๋ฅด๋‹ค๋ฉด ์ˆ˜์ •ํ•˜์„ธ์š”
- ์ถ”๊ฐ€ ๋ฌธ์„œ์— ๋ช…์‹œ๋œ ์ •๋ณด๋งŒ ์‚ฌ์šฉํ•˜์„ธ์š”
- ํ•˜๋‚˜์˜ ์™„๊ฒฐ๋œ ๋‹ต๋ณ€์œผ๋กœ ์ž‘์„ฑํ•˜์„ธ์š”
- ํ™•์‹คํ•˜์ง€ ์•Š์€ ์ถœ์ฒ˜๋‚˜ ํŽ˜์ด์ง€๋Š” ์–ธ๊ธ‰ํ•˜์ง€ ๋งˆ์„ธ์š”

์งˆ๋ฌธ: {question}
๋‹ต๋ณ€:"""
        )
    else:
        # Initial-pass prompt (English).
        question_prompt = PromptTemplate(
            input_variables=["context_str", "question"],
            template="""
Here are the retrieved document fragments:
{context_str}

Please answer the question based on the above documents.

**Important rules:**
- Only use information explicitly stated in the documents
- If citing sources, only mention what is clearly indicated in the documents above
- Do not guess or infer page numbers not shown in the context
- If unsure, state "not confirmed in the provided documents"

Question: {question}
Answer:"""
        )
        # Refine prompt (English).
        refine_prompt = PromptTemplate(
            input_variables=["question", "existing_answer", "context_str"],
            template="""
Existing answer:
{existing_answer}

Additional documents:
{context_str}

Refine the existing answer using the additional documents.

**Rules:**
- Only use information explicitly stated in the additional documents
- Create one coherent final answer
- Do not mention uncertain sources or page numbers

Question: {question}
Answer:"""
        )
    return question_prompt, refine_prompt
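# In a refine chain, question_prompt is applied to the first retrieved chunk
# and refine_prompt once per remaining chunk, with each previous answer fed
# back in as {existing_answer}. A rough sketch of the control flow:
#
#   answer = llm(question_prompt.format(context_str=chunks[0], question=q))
#   for chunk in chunks[1:]:
#       answer = llm(refine_prompt.format(question=q,
#                                         existing_answer=answer,
#                                         context_str=chunk))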
def build_rag_chain(llm, vectorstore, language="ko", k=7):
    """Build the RAG chain: a refine-type RetrievalQA over the top-k retrieved chunks."""
    question_prompt, refine_prompt = create_refine_prompts_with_pages(language)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="refine",
        retriever=vectorstore.as_retriever(search_kwargs={"k": k}),
        chain_type_kwargs={
            "question_prompt": question_prompt,
            "refine_prompt": refine_prompt
        },
        return_source_documents=True
    )
    return qa_chain
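# Minimal usage sketch (assumes a vector store already built by this repo's
# vector_store.py; get_embeddings/load_vector_store/load_llama_model are the
# local helpers imported above):
#
#   embeddings = get_embeddings(device="cuda")
#   vectorstore = load_vector_store(embeddings, load_path="vector_db")
#   qa_chain = build_rag_chain(load_llama_model(), vectorstore, language="ko", k=7)
#   result = qa_chain.invoke({"query": "..."})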
def ask_question_with_pages(qa_chain, question):
    """Run a query through the chain and print the answer plus a per-file source summary."""
    result = qa_chain.invoke({"query": question})
    # Keep only the text after a leading "A:" marker, if the model echoed one
    answer = result['result']
    final_answer = answer.split("A:")[-1].strip() if "A:" in answer else answer.strip()
    print(f"\n๐Ÿงพ Question: {question}")
    print(f"\n๐ŸŸข Final answer: {final_answer}")
    # Metadata debugging output (disabled)
    # debug_metadata_info(result["source_documents"])
    # Group the source documents by file, collecting page/section/type metadata
    print("\n๐Ÿ“š Source document summary:")
    source_info = {}
    for doc in result["source_documents"]:
        source = doc.metadata.get('source', 'N/A')
        page = doc.metadata.get('page', 'N/A')
        doc_type = doc.metadata.get('type', 'N/A')
        section = doc.metadata.get('section', None)
        total_pages = doc.metadata.get('total_pages', None)
        filename = doc.metadata.get('filename', 'N/A')
        if filename == 'N/A':
            filename = os.path.basename(source) if source != 'N/A' else 'N/A'
        if filename not in source_info:
            source_info[filename] = {
                'pages': set(),
                'sections': set(),
                'types': set(),
                'total_pages': total_pages
            }
        if page != 'N/A':
            # Values prefixed with '์„น์…˜' ("section") are section labels written
            # by the indexing step, not page numbers
            if isinstance(page, str) and page.startswith('์„น์…˜'):
                source_info[filename]['sections'].add(page)
            else:
                source_info[filename]['pages'].add(page)
        if section is not None:
            source_info[filename]['sections'].add(f"์„น์…˜ {section}")
        source_info[filename]['types'].add(doc_type)
    # Print the summary
    total_chunks = len(result["source_documents"])
    print(f"Total chunks used: {total_chunks}")
    for filename, info in source_info.items():
        print(f"\n- {filename}")
        # Total page count, if known
        if info['total_pages']:
            print(f"  Total pages: {info['total_pages']}")
        # Page numbers (sorted as strings, since pages may be int or str)
        if info['pages']:
            pages_list = sorted(info['pages'], key=str)
            print(f"  Pages: {', '.join(map(str, pages_list))}")
        # Section labels
        if info['sections']:
            sections_list = sorted(info['sections'])
            print(f"  Sections: {', '.join(sections_list)}")
        # Neither pages nor sections were found
        if not info['pages'] and not info['sections']:
            print("  Pages: no information")
        # Document types
        types_str = ', '.join(sorted(info['types']))
        print(f"  Types: {types_str}")
    return result
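# Illustrative console output for the summary block above (filenames and
# numbers are made up):
#
#   ๐Ÿ“š Source document summary:
#   Total chunks used: 7
#
#   - manual.pdf
#     Total pages: 120
#     Pages: 3, 17
#     Sections: ์„น์…˜ 2
#     Types: text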
# The original ask_question is superseded by ask_question_with_pages
def ask_question(qa_chain, question):
    """Compatibility wrapper around ask_question_with_pages."""
    return ask_question_with_pages(qa_chain, question)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RAG refine system (with page number support)")
    parser.add_argument("--vector_store", type=str, default="vector_db", help="Vector store path")
    parser.add_argument("--model", type=str, default="LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct", help="LLM model ID")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to use")
    parser.add_argument("--k", type=int, default=7, help="Number of documents to retrieve")
    parser.add_argument("--language", type=str, default="ko", choices=["ko", "en"], help="Prompt language")
    parser.add_argument("--query", type=str, help="Question (omit to run interactive mode)")
    args = parser.parse_args()
    embeddings = get_embeddings(device=args.device)
    vectorstore = load_vector_store(embeddings, load_path=args.vector_store)
    # NOTE: args.model is parsed above but not passed to load_llama_model() in the original code
    llm = load_llama_model()
    qa_chain = build_rag_chain(llm, vectorstore, language=args.language, k=args.k)
    print("๐ŸŸข RAG system with page-number support is ready!")
    if args.query:
        ask_question_with_pages(qa_chain, args.query)
    else:
        print("๐Ÿ’ฌ Interactive mode started (type 'exit', 'quit', or '์ข…๋ฃŒ' to quit)")
        while True:
            try:
                query = input("\nQuestion: ").strip()
                if query.lower() in ["exit", "quit", "์ข…๋ฃŒ"]:
                    break
                if query:  # Skip empty input
                    ask_question_with_pages(qa_chain, query)
            except KeyboardInterrupt:
                print("\n\nExiting.")
                break
            except Exception as e:
                print(f"โ— Error: {e}\nPlease try again.")