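"""Adaptive RAG workflow built with LangGraph.

The graph routes an incoming question to either the vectorstore retriever or
a web search, grades retrieved documents for relevance, generates an answer,
and then checks the generation for grounding and answer quality, regenerating
or rewriting the query as needed.
"""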
import asyncio

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_ollama import ChatOllama
from langgraph.graph import END, START, StateGraph

from src.agents.state import GraphState
from src.llm.graders import (
    grade_document_relevance,
    check_hallucination,
    grade_answer_quality,
)
from src.llm.query_rewriter import rewrite_query
from src.tools.web_search import AdvancedWebCrawler
from src.vectorstore.pinecone_db import get_retriever


def perform_web_search(question: str):
    """
    Perform a web search using the AdvancedWebCrawler.

    Args:
        question (str): User's input question.

    Returns:
        list: Crawled results; each item exposes 'content' and 'url' keys,
            as consumed by the web_search node below.
    """
    crawler = AdvancedWebCrawler(
        max_search_results=5,
        word_count_threshold=50,
        content_filter_type='f',
        filter_threshold=0.48
    )
    # search_and_crawl is a coroutine, so run it to completion synchronously.
    results = asyncio.run(crawler.search_and_crawl(question))
    return results
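# Example usage (hypothetical question; assumes the crawler's search backend
# is configured):
#     results = perform_web_search("What is retrieval-augmented generation?")
#     print(results[0]['url'])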


def create_adaptive_rag_workflow(retriever, llm, top_k=5, enable_websearch=False):
    """
    Create the adaptive RAG workflow graph.

    Args:
        retriever: Vector store retriever.
        llm: Chat model used for generation, grading, and query rewriting.
        top_k (int): Number of documents to retrieve.
        enable_websearch (bool): Route questions to web search instead of
            the vectorstore.

    Returns:
        Compiled LangGraph workflow.
    """
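    # Each node below receives the shared GraphState and returns a partial
    # update using the keys 'question', 'documents', and 'generation'.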
    def retrieve(state: GraphState):
        """Retrieve documents from the vectorstore."""
        print("---RETRIEVE---")
        question = state['question']
        documents = retriever.invoke(question, top_k)
        print(f"Retrieved {len(documents)} documents.")
        print(documents)
        return {"documents": documents, "question": question}

    def route_to_datasource(state: GraphState):
        """Route the question to web search or the vectorstore."""
        print("---ROUTE QUESTION---")
        # Routing is controlled by the enable_websearch flag rather than an
        # LLM-based router.
        if enable_websearch:
            print("---ROUTE TO WEB SEARCH---")
            return "web_search"
        else:
            print("---ROUTE TO RAG---")
            return "vectorstore"

    def generate_answer(state: GraphState):
        """Generate an answer from the retrieved documents."""
        print("---GENERATE---")
        question = state['question']
        documents = state['documents']

        context = "\n\n".join([doc["page_content"] for doc in documents])
        prompt_template = PromptTemplate.from_template(
            """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.

Question: {question}

Context: {context}

Answer:"""
        )

        rag_chain = prompt_template | llm | StrOutputParser()
        generation = rag_chain.invoke({"context": context, "question": question})

        return {"generation": generation, "documents": documents, "question": question}

    def grade_documents(state: GraphState):
        """Filter out documents that are not relevant to the question."""
        print("---GRADE DOCUMENTS---")
        question = state['question']
        documents = state['documents']

        filtered_docs = []
        for doc in documents:
            score = grade_document_relevance(question, doc["page_content"], llm)
            if score == "yes":
                filtered_docs.append(doc)

        return {"documents": filtered_docs, "question": question}

    def web_search(state: GraphState):
        """Fetch documents via web search."""
        print("---WEB SEARCH---")
        question = state['question']

        results = perform_web_search(question)
        # Normalize crawler results into the document shape the rest of the
        # graph expects.
        web_documents = [
            {
                "page_content": result['content'],
                "metadata": {"source": result['url']}
            } for result in results
        ]

        return {"documents": web_documents, "question": question}

    def check_generation_quality(state: GraphState):
        """Grade the generation for grounding and answer quality.

        Returns one of the edge labels: "regenerate" (hallucinated),
        "rewrite" (poor answer), or "end" (good answer).
        """
        print("---ASSESS GENERATION---")
        question = state['question']
        documents = state['documents']
        generation = state['generation']

        # Grounding check. The check_hallucination grader is assumed to take
        # the documents, the generation, and the llm, and to return "yes"
        # when the generation is grounded in the documents; without this call
        # the "regenerate" edge below is unreachable.
        hallucination_score = check_hallucination(documents, generation, llm)
        if hallucination_score != "yes":
            print("---Generation is hallucinated. Regenerating.---")
            return "regenerate"
        print("---Generation is not hallucinated.---")

        quality_score = grade_answer_quality(question, generation, llm)
        if quality_score == "yes":
            print("---Answer quality is good.---")
            return "end"
        print("---Answer quality is poor.---")
        return "rewrite"

    workflow = StateGraph(GraphState)

    workflow.add_node("vectorstore", retrieve)
    workflow.add_node("web_search", web_search)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate_answer)
    workflow.add_node("rewrite_query", lambda state: {
        "question": rewrite_query(state['question'], llm),
        "documents": [],
        "generation": None
    })

    workflow.add_conditional_edges(
        START,
        route_to_datasource,
        {
            "web_search": "web_search",
            "vectorstore": "vectorstore"
        }
    )

    workflow.add_edge("web_search", "generate")
    workflow.add_edge("vectorstore", "grade_documents")

    workflow.add_conditional_edges(
        "grade_documents",
        lambda state: "generate" if state['documents'] else "rewrite_query"
    )

    workflow.add_edge("rewrite_query", "vectorstore")

    workflow.add_conditional_edges(
        "generate",
        check_generation_quality,
        {
            "end": END,
            "regenerate": "generate",
            "rewrite": "rewrite_query"
        }
    )

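    # Resulting topology:
    #   START -> web_search (websearch enabled) or vectorstore (default)
    #   web_search -> generate
    #   vectorstore -> grade_documents -> generate (relevant docs found)
    #                                  -> rewrite_query (no relevant docs)
    #   rewrite_query -> vectorstore
    #   generate -> END ("end"), generate ("regenerate"),
    #               or rewrite_query ("rewrite")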
    app = workflow.compile()
    return app


def run_adaptive_rag(retriever, question: str, llm, top_k=5, enable_websearch=False):
    """
    Run the adaptive RAG workflow for a given question.

    Args:
        retriever: Vector store retriever.
        question (str): User's input question.
        llm: Chat model used throughout the workflow.
        top_k (int): Number of documents to retrieve.
        enable_websearch (bool): Route the question to web search instead of
            the vectorstore.

    Returns:
        str: Generated answer.
    """
    workflow = create_adaptive_rag_workflow(retriever, llm, top_k, enable_websearch=enable_websearch)

    final_state = None
    for output in workflow.stream({"question": question}, config={"recursion_limit": 5}):
        for key, value in output.items():
            print(f"Node '{key}':")
            final_state = value

    # Guard against the stream producing no output at all.
    if final_state is None:
        return 'No answer could be generated.'
    return final_state.get('generation', 'No answer could be generated.')
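# Example (hypothetical):
#     answer = run_adaptive_rag(retriever, "What is gender-based violence?", llm)
#     print(answer)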


if __name__ == "__main__":
    from src.vectorstore.pinecone_db import (
        PINECONE_API_KEY,
        ingest_data,
        load_documents,
        process_chunks,
        save_to_parquet,
    )
    from pinecone import Pinecone

    pc = Pinecone(api_key=PINECONE_API_KEY)

    file_paths = [
        './data/UNIT 2 GENDER BASED VIOLENCE.pptx'
    ]

    try:
        print("Loading documents...")
        markdown_path = load_documents(file_paths)

        print("Processing chunks...")
        chunks = process_chunks(markdown_path)

        print("Saving to Parquet...")
        parquet_path = save_to_parquet(chunks)

        print("Ingesting into Pinecone...")
        ingest_data(
            parquet_path=parquet_path,
            text_column="text",
            pinecone_client=pc,
        )

        print("\nTesting retrieval...")
        retriever = get_retriever(
            pinecone_client=pc,
            index_name="vector-index",
            namespace="rag"
        )
    except Exception as e:
        print(f"Error in pipeline: {str(e)}")

    llm = ChatOllama(model="llama3.2", temperature=0.1, num_predict=256, top_p=0.5)

    test_questions = [
        "what are the trending papers that are published in NeurIPS 2024?"
    ]

    for question in test_questions:
        print(f"\n--- Processing Question: {question} ---")
        answer = run_adaptive_rag(retriever, question, llm)
        print("\nFinal Answer:", answer)