File size: 2,395 Bytes
fb95c43
 
 
 
 
 
 
 
 
5a59a40
 
fba76a5
5a59a40
153e9c1
 
fb95c43
 
 
ddfe3b8
fb95c43
 
 
 
 
 
 
ddfe3b8
fb95c43
5a59a40
 
fb95c43
ddfe3b8
fb95c43
 
 
 
 
 
 
 
153e9c1
 
 
 
fb95c43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a59a40
 
fb95c43
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# HF libraries
from langchain_huggingface import HuggingFaceEndpoint
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
# Import things that are needed generically
from langchain.tools.render import render_text_description
import os
from dotenv import load_dotenv
# local cache
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache # sqlite
#from langchain.cache import InMemoryCache # in memory cache
from rag_app.structured_tools.structured_tools import (
    google_search, knowledgeBase_search
)

from langchain.prompts import PromptTemplate
from rag_app.templates.react_json_with_memory_ger import template_system
# from innovation_pathfinder_ai.utils import logger
# logger = logger.get_console_logger("hf_mixtral_agent")

# Load environment variables from a local .env file.
# NOTE(review): load_dotenv returns a bool (whether the file was found),
# not a config mapping — the name `config` is misleading; TODO consider renaming.
config = load_dotenv(".env")
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')  # HF Inference API auth token
GOOGLE_CSE_ID = os.getenv('GOOGLE_CSE_ID')    # presumably consumed by the google_search tool — verify
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')  # presumably consumed by the google_search tool — verify
LLM_MODEL = os.getenv('LLM_MODEL')            # HF Hub repo id of the model to run

# Cache LLM responses on disk so repeated identical prompts hit SQLite
# instead of the remote endpoint.
set_llm_cache(SQLiteCache(database_path=".cache.db"))

# Chat LLM backed by a Hugging Face Inference Endpoint.
# A low temperature keeps the ReAct JSON output predictable enough to parse,
# and repetition_penalty discourages the model from looping on its scratchpad.
llm = HuggingFaceEndpoint(
    repo_id=LLM_MODEL,
    temperature=0.1,
    max_new_tokens=1024,
    repetition_penalty=1.2,
    return_full_text=False,
)


# Tools the agent may invoke. The knowledge-base search is listed first,
# followed by web search; the remaining entries are currently disabled.
tools = [
    knowledgeBase_search,
    google_search,
    # disabled for now:
    # web_research,
    # ask_user,
]

# Build the system prompt and pre-fill the tool-related slots so the runtime
# chain only needs to supply input, chat history, and the agent scratchpad.
prompt = PromptTemplate.from_template(template=template_system).partial(
    tools=render_text_description(tools),
    tool_names=", ".join(tool.name for tool in tools),
)


# Stop generation as soon as the model starts emitting an Observation line:
# observations come from tool execution, not from the model itself.
chat_model_with_stop = llm.bind(stop=["\nObservation"])

# ReAct agent pipeline: map the executor's inputs into the prompt variables,
# run the stop-bound model, then parse its single JSON action output.
input_mapping = {
    "input": lambda x: x["input"],
    "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
    "chat_history": lambda x: x["chat_history"],
}
agent = (
    input_mapping
    | prompt
    | chat_model_with_stop
    | ReActJsonSingleInputOutputParser()
)

# Instantiate the AgentExecutor that drives the ReAct loop: invoke the agent,
# run the chosen tool, feed the observation back, repeat until a final answer.
agent_executor = AgentExecutor(
    agent=agent, 
    tools=tools, 
    verbose=True,
    max_iterations=20,       # cap on agent/tool round-trips
    max_execution_time=90,  # wall-clock timeout in seconds (value is 90, not 60 as an earlier comment claimed)
    return_intermediate_steps=True,   # expose (action, observation) pairs to the caller
    handle_parsing_errors=True,       # feed parser failures back to the model instead of raising
    )