from typing import TypedDict, Annotated, Optional
from langgraph.graph.message import add_messages
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import START, StateGraph, END
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
import os
from dotenv import load_dotenv
from tools import download_file_from_url, basic_web_search, extract_url_content, wikipedia_reader, transcribe_audio_file, question_youtube_video

# Load environment variables from .env file
load_dotenv()

OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
MAIN_LLM_MODEL = os.getenv("MAIN_LLM_MODEL", "google/gemini-2.0-flash-lite-001")
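
# A minimal .env sketch this module assumes (the values below are placeholders,
# not real credentials):
#   OPENROUTER_API_KEY=sk-or-v1-your-key-here
#   MAIN_LLM_MODEL=google/gemini-2.0-flash-lite-001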

# Fail fast if the OpenRouter API key is missing
if not OPENROUTER_API_KEY:
    raise ValueError("OPENROUTER_API_KEY is not set. Please ensure it is defined in your .env file or environment variables.")


def create_agent_graph():
    """Builds and compiles the LangGraph agent graph around a tool-calling LLM."""
    main_llm = ChatOpenAI(
        model=MAIN_LLM_MODEL,  # e.g., "mistralai/mistral-7b-instruct"
        api_key=SecretStr(OPENROUTER_API_KEY),  # Your OpenRouter API key
        base_url="https://openrouter.ai/api/v1",  # Standard OpenRouter API base
        verbose=True,  # Optional: for debugging
    )


    tools = [download_file_from_url, basic_web_search, extract_url_content, wikipedia_reader, transcribe_audio_file, question_youtube_video]  # All imported from tools.py above
    chat_with_tools = main_llm.bind_tools(tools)
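    # bind_tools attaches each tool's schema to the model's requests so it can emit
    # structured tool calls; tools_condition later routes on whether a call was made.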

    class AgentState(TypedDict):
        messages: Annotated[list[AnyMessage], add_messages]  # Chat history, merged via the add_messages reducer
        file_url: Optional[str]          # URL of a reference file, if any
        file_ext: Optional[str]          # Extension of that file
        local_file_path: Optional[str]   # Set after download_file_from_url runs
        final_answer: Optional[str]      # Populated by format_final_answer_node
    
    def assistant(state: AgentState):
        # Only the updated key needs to be returned; LangGraph merges partial
        # updates, so the remaining state fields pass through unchanged.
        return {"messages": [chat_with_tools.invoke(state["messages"])]}
    
    def file_path_updater_node(state: AgentState):
        """Parses the download tool's response and records the local file path in state."""
        download_tool_response = state["messages"][-1].content
        # The download tool is expected to end its response with "Local File Path: <path>".
        file_path = download_tool_response.split("Local File Path: ")[-1].strip()
        return {"local_file_path": file_path}
    
    def file_path_condition(state: AgentState) -> str:
        """Routes download results to the file-path updater; all other tool results return to the assistant."""
        if state["messages"] and isinstance(state["messages"][-1], ToolMessage):
            tool_response = state["messages"][-1]
            if tool_response.name == "download_file_from_url":
                return "update_file_path"  # A file was downloaded; record its local path
        return "assistant"  # Otherwise, continue with the assistant node

    def format_final_answer_node(state: AgentState) -> dict:
        """
        Extracts the final answer from the last assistant message.
        This node is reached when the assistant has completed its task.
        """
        final_answer = state["messages"][-1].content if state["messages"] else None
        if not final_answer:
            return {}
        # If "FINAL ANSWER:" is absent, split() leaves the whole string intact.
        return {"final_answer": final_answer.split("FINAL ANSWER:")[-1].strip()}
        
        
    # The graph
    builder = StateGraph(AgentState)
    
    builder.add_node("assistant", assistant)
    builder.add_edge(START, "assistant")
    builder.add_node("tools", ToolNode(tools))
    builder.add_node("file_path_updater_node", file_path_updater_node)
    builder.add_node("format_final_answer_node", format_final_answer_node)
    
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
        {
            "tools": "tools", 
            "__end__": "format_final_answer_node"  # This is the end node for the assistant
        }
    )
    builder.add_conditional_edges(
        "tools",
        file_path_condition,
        {
            "update_file_path": "file_path_updater_node",
            "assistant": "assistant"
        }
    )
    
    builder.add_edge("file_path_updater_node", "assistant")
    builder.add_edge("format_final_answer_node", END)
    graph = builder.compile()
    return graph

class BasicAgent:
    """
    A basic agent that can answer questions and download files.
    Requires a system message defined in 'system_prompt.txt'; if that prompt makes
    the model prefix its reply with "FINAL ANSWER:", the graph strips the prefix
    and returns only the answer.
    """
    def __init__(self, graph=None):

        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            self.system_message = SystemMessage(content=f.read())
        
        if graph is None:
            self.graph = create_agent_graph()
        else:
            self.graph = graph

    def __call__(self, question: str, file_url: Optional[str] = None, file_ext: Optional[str] = None) -> str:
        """
        Call the agent with a question and optional file URL and extension.
        
        Args:
            question (str): The user's question.
            file_url (Optional[str]): The URL of the file to download.
            file_ext (Optional[str]): The file extension for the downloaded file.
        
        Returns:
            str: The agent's response.
        """
        if file_url and file_ext:
            question += f"\nREFERENCE FILE MUST BE RETRIEVED\nFile URL: {file_url}, File Extension: {file_ext}\nUSE A TOOL TO DOWNLOAD THIS FILE."
        state = {
            "messages": [self.system_message, HumanMessage(content=question)],
            "file_url": file_url,
            "file_ext": file_ext,
            "local_file_path": None,
            "final_answer": None
        }
        response = self.graph.invoke(state)
        for m in response["messages"]:
            m.pretty_print()  # Print the full conversation for debugging
        return response["final_answer"] or "No final answer generated."
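

# A minimal usage sketch, assuming a valid .env and a "system_prompt.txt" in the
# working directory. The question below is a hypothetical placeholder, not a value
# from this project.
if __name__ == "__main__":
    agent = BasicAgent()
    print(agent("What is the capital of France?"))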