# NINU / agent.py
# Author: Ali-Developments — last change: "Update agent.py" (commit 63de395, verified)
# Source page controls: raw | history | blame — file size 4 kB
import os
from dotenv import load_dotenv
from langchain.docstore.document import Document
from langchain_community.retrievers import BM25Retriever
from langchain.tools import Tool
from langchain.utilities import SerpAPIWrapper
from langgraph.graph.message import add_messages
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import AnyMessage, HumanMessage
from langchain_groq import ChatGroq
from typing import TypedDict, Annotated
import fitz # PyMuPDF
# Pull variables from a local .env file into the process environment first,
# so the API keys below may come from either the real environment or that file.
load_dotenv()
# Either key is None when the variable is unset; nothing here validates them.
groq_api_key = os.environ.get("GROQ_API_KEY")
serpapi_api_key = os.environ.get("SERPAPI_API_KEY")
# --- PDF parsing ---
def parse_pdfs(uploaded_files):
    """Extract the plain text of each uploaded PDF into a LangChain Document.

    Args:
        uploaded_files: Iterable of file-like objects exposing ``read()``
            (returning the raw PDF bytes) and a ``name`` attribute.
            NOTE(review): presumably web-upload objects; confirm against
            the caller. ``None`` is treated as "no files".

    Returns:
        list[Document]: one Document per file. Its ``page_content`` is the
        text of every page concatenated in order, and its
        ``metadata["source"]`` is the file's ``name``.
    """
    if uploaded_files is None:
        return []

    pdf_docs = []
    for uploaded_file in uploaded_files:
        with fitz.open(stream=uploaded_file.read(), filetype="pdf") as doc:
            # Join the page texts once; repeated ``+=`` on a growing string
            # can be quadratic in the number of pages.
            text = "".join(page.get_text() for page in doc)
        pdf_docs.append(
            Document(page_content=text, metadata={"source": uploaded_file.name})
        )
    return pdf_docs
# --- BM25 Retrieval ---
def build_retriever(all_docs):
    """Create a BM25 keyword retriever indexed over *all_docs*.

    Args:
        all_docs: The LangChain ``Document`` objects to index.

    Returns:
        The retriever produced by ``BM25Retriever.from_documents`` with its
        default settings.
    """
    retriever = BM25Retriever.from_documents(all_docs)
    return retriever
def extract_text(query: str, retriever, top_k: int = 3):
    """Return the text of the best-matching documents for *query*.

    Args:
        query: Search string passed to ``retriever.invoke``.
        retriever: Any object whose ``invoke(query)`` returns a sequence of
            documents exposing ``page_content``. The retriever's own result
            limit still applies, so ``top_k`` can only trim what it returns.
        top_k: Maximum number of results to include. Defaults to 3, the
            previously hard-coded value. Must be at least 1.

    Returns:
        str: the ``page_content`` of the first ``top_k`` results separated by
        blank lines. If the retriever returns nothing, an Arabic message
        saying no matching information was found in the files.

    Raises:
        ValueError: if ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    results = retriever.invoke(query)
    if not results:
        # User-facing message (Arabic): "No matching information was found in the files."
        return "لم يتم العثور على معلومات مطابقة في الملفات."
    return "\n\n".join(doc.page_content for doc in results[:top_k])
# --- Create NINU Agent ---
def create_ninu_agent(user_docs=None, *, model_name: str = "deepseek-r1-distill-llama-70b"):
    """Build and compile the NINU tool-calling agent graph.

    The agent exposes two tools to the LLM:

    * ``NINU_Lec_retriever`` runs a BM25 keyword search over *user_docs*.
      When no documents were supplied, it returns an Arabic
      "no uploaded PDF files to search" message.
    * ``WebSearch`` runs a web search through ``SerpAPIWrapper.run``.

    Args:
        user_docs: Optional list of LangChain ``Document`` objects. When it
            is empty or ``None``, no retriever is built.
        model_name: Groq model id for the chat LLM. Keyword-only; defaults to
            the model this agent previously hard-coded.

    Returns:
        The compiled LangGraph graph, invoked with ``{"messages": [...]}``.
    """
    # Only build an index when there is something to index.
    bm25_retriever = build_retriever(user_docs) if user_docs else None

    def pdf_tool_func(q):
        if bm25_retriever is None:
            # User-facing message (Arabic): "No uploaded PDF files to search."
            return "لا توجد ملفات PDF مرفوعة للبحث."
        return extract_text(q, bm25_retriever)

    NINU_tool = Tool(
        name="NINU_Lec_retriever",
        func=pdf_tool_func,
        description="Retrieves content from uploaded PDFs based on a query.",
    )
    serpapi = SerpAPIWrapper(serpapi_api_key=serpapi_api_key)
    SerpAPI_tool = Tool(
        name="WebSearch",
        func=serpapi.run,
        description="Searches the web for recent information.",
    )
    tools = [NINU_tool, SerpAPI_tool]

    llm = ChatGroq(model=model_name, groq_api_key=groq_api_key)
    # Advertise the tool schemas to the model so it can request tool calls.
    llm_with_tools = llm.bind_tools(tools)

    class AgentState(TypedDict):
        # Conversation history; the add_messages reducer merges messages a
        # node returns into this list instead of replacing it.
        messages: Annotated[list[AnyMessage], add_messages]

    def assistant(state: AgentState):
        # One LLM turn over the full history; the reply is merged into state.
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    builder = StateGraph(AgentState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # tools_condition routes to the "tools" node when the assistant's last
    # message requests a tool call, and ends the run otherwise.
    builder.add_conditional_edges("assistant", tools_condition)
    # Tool results go back to the assistant for another turn.
    builder.add_edge("tools", "assistant")
    return builder.compile()
# --- Main interaction function ---
def run_ninu(query, user_docs=None):
    """Answer *query* with a freshly built NINU agent and return its reply.

    A new agent is created on every call via ``create_ninu_agent``. The
    conversation sent to it is two human messages:

    1. a fixed instruction prompt describing the tools and the required
       ``FINAL ANSWER:`` format;
    2. the user's *query*.

    Args:
        query: The user's question.
        user_docs: Optional documents forwarded to ``create_ninu_agent``.

    Returns:
        The ``content`` of the last message in the agent's final state.
    """
    # Kept as a human message rather than a system message; the text is
    # identical to the original triple-quoted prompt.
    instructions = (
        "\n"
        "You are a general AI assistant with access to two tools:\n"
        "1. NINU_Lec_retriever: retrieves content from uploaded PDFs based on a query.\n"
        "2. WebSearch: performs web searches to answer questions about current events or general knowledge.\n"
        "Based on the user's query, decide whether to use NINU_Lec_retriever, WebSearch, or both.\n"
        "When answering, report your thoughts and finish your answer with the following template:\n"
        "FINAL ANSWER: [YOUR FINAL ANSWER].\n"
        "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.\n"
        "If you are asked for a number, don't use commas or units (like $, %, etc.) unless specified.\n"
        "If you are asked for a string, avoid articles, abbreviations, and write digits in plain text unless specified.\n"
    )
    agent = create_ninu_agent(user_docs)
    messages = [HumanMessage(content=instructions), HumanMessage(content=query)]
    final_state = agent.invoke({"messages": messages})
    return final_state["messages"][-1].content