Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Sequence, List | |
| import streamlit as st | |
| from langchain.agents import AgentExecutor | |
| from langchain.schema.language_model import BaseLanguageModel | |
| from langchain.tools import BaseTool | |
| from backend.chat_bot.message_converter import DefaultClickhouseMessageConverter | |
| from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT | |
| from backend.constants.streamlit_keys import AVAILABLE_RETRIEVAL_TOOLS | |
| from backend.constants.variables import GLOBAL_CONFIG, RETRIEVER_TOOLS | |
| from logger import logger | |
| try: | |
| from sqlalchemy.orm import declarative_base | |
| except ImportError: | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.prompts.chat import MessagesPlaceholder | |
| from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory | |
| from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent | |
| from langchain.schema.messages import SystemMessage | |
| from langchain.memory import SQLChatMessageHistory | |
def create_agent_executor(
    agent_name: str,
    session_id: str,
    llm: BaseLanguageModel,
    tools: Sequence[BaseTool],
    system_prompt: str,
    **kwargs
) -> AgentExecutor:
    """Build an OpenAI-functions ``AgentExecutor`` whose chat memory is stored in MyScale.

    Conversation history for ``session_id`` is read/written through
    ``SQLChatMessageHistory`` using a ``clickhouse://`` connection to the
    ``chat`` database on the server given by the ``MYSCALE_*`` environment
    variables. The history is wrapped in an ``AgentTokenBufferMemory`` and
    exposed to the prompt through a ``history`` messages placeholder.

    Args:
        agent_name: Name handed to ``DefaultClickhouseMessageConverter``;
            spaces are replaced with underscores first (presumably because it
            is used as an identifier such as a table name -- TODO confirm in
            the converter).
        session_id: Identifier of the chat session whose history is used.
        llm: Language model that drives the agent and backs the token-buffer
            memory.
        tools: Tools the agent is allowed to call.
        system_prompt: Content of the system message at the head of the prompt.
        **kwargs: Extra keyword arguments forwarded to ``AgentExecutor``.
            They may override the defaults ``verbose=True`` and
            ``return_intermediate_steps=True``.

    Returns:
        The configured ``AgentExecutor``.

    Raises:
        KeyError: If ``MYSCALE_USER``, ``MYSCALE_PASSWORD``, ``MYSCALE_HOST``
            or ``MYSCALE_PORT`` is missing from the environment.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import quote

    agent_name = agent_name.replace(" ", "_")

    # Percent-encode the credentials: characters such as '@', ':', '/', '#'
    # or '%' in a user name or password would otherwise corrupt the URL.
    user = quote(os.environ["MYSCALE_USER"], safe="")
    password = quote(os.environ["MYSCALE_PASSWORD"], safe="")
    conn_str = f'clickhouse://{user}:{password}@{os.environ["MYSCALE_HOST"]}:{os.environ["MYSCALE_PORT"]}'
    # Same test as before: only a value equal to False selects plain HTTP.
    protocol = "http" if GLOBAL_CONFIG.myscale_enable_https == False else "https"  # noqa: E712
    chat_memory = SQLChatMessageHistory(
        session_id,
        connection_string=f"{conn_str}/chat?protocol={protocol}",
        custom_message_converter=DefaultClickhouseMessageConverter(agent_name),
    )
    memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=system_prompt),
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    # Defaults first so callers may override them via **kwargs instead of
    # hitting "got multiple values for keyword argument".
    executor_options = {
        "verbose": True,
        "return_intermediate_steps": True,
        **kwargs,
    }
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        **executor_options,
    )
def build_agents(
    session_id: str,
    tool_names: List[str],
    model: str = "gpt-3.5-turbo-0125",
    temperature: float = 0.6,
    system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> AgentExecutor:
    """Create a streaming chat agent that can call the selected retrieval tools.

    Args:
        session_id: Chat session whose stored history backs the agent memory.
        tool_names: Keys of the tools to enable. They are looked up in the
            tool registry kept in ``st.session_state`` under
            ``AVAILABLE_RETRIEVAL_TOOLS``, falling back to ``RETRIEVER_TOOLS``.
        model: OpenAI chat model name.
        temperature: Sampling temperature for the model.
        system_prompt: System message placed at the head of the agent prompt.

    Returns:
        The ``AgentExecutor`` produced by ``create_agent_executor`` with the
        agent name ``"chat_memory"``.

    Raises:
        KeyError: If a name in ``tool_names`` is not in the registry, including
            when no registry is present in the session state at all.
    """
    chat_llm = ChatOpenAI(
        model_name=model,
        temperature=temperature,
        base_url=GLOBAL_CONFIG.openai_api_base,
        api_key=GLOBAL_CONFIG.openai_api_key,
        streaming=True
    )
    registry = st.session_state.get(AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS))
    if registry is None:
        # No registry in the session yet: treat it as empty so a requested
        # tool yields a descriptive KeyError instead of
        # "'NoneType' object is not subscriptable".
        registry = {}

    selected_tools = []
    for name in tool_names:
        try:
            selected_tools.append(registry[name])
        except KeyError:
            raise KeyError(
                f"Unknown retrieval tool {name!r}; available tools: {list(registry)}"
            ) from None

    logger.info(f"create agent, use tools: {selected_tools}")
    agent = create_agent_executor(
        agent_name="chat_memory",
        session_id=session_id,
        llm=chat_llm,
        tools=selected_tools,
        system_prompt=system_prompt
    )
    return agent