Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| from dotenv import load_dotenv | |
| from langchain.memory import ConversationSummaryMemory | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_community.utilities import SQLDatabase | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_openai import ChatOpenAI | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain.agents import create_tool_calling_agent, AgentExecutor, Tool | |
| from langchain_community.vectorstores import FAISS | |
| from config.settings import Settings | |
# Bootstrap configuration: pull secrets from the environment (.env supported)
# and the database URI from application settings.
load_dotenv()

open_api_key_token = os.getenv('OPENAI_API_KEY')
db_uri = Settings.DB_URI
class ChatAgentService:
    """Tool-calling agent that answers user questions from two sources:

    - a SQL database (via dynamically generated SQL queries), and
    - FAISS vector indexes built from uploaded documents.

    The agent chooses between the two tools per question; conversation
    history is summarized by ``ConversationSummaryMemory``.
    """

    def __init__(self):
        # Database and model setup.
        self.db = SQLDatabase.from_uri(db_uri)
        self.llm = ChatOpenAI(
            model="gpt-3.5-turbo-0125",
            api_key=open_api_key_token,
            max_tokens=150,
            temperature=0.2,
        )
        self.memory = ConversationSummaryMemory(llm=self.llm, return_messages=True)

        # Tools setup. NOTE: ``Tool`` has no ``tool_choice`` field (the original
        # passed tool_choice="required", which is not a valid Tool argument);
        # which tool is invoked is decided by the model from the descriptions.
        self.tools = [
            Tool(
                name="DatabaseQuery",
                func=self.database_tool,
                description="Queries the SQL database using dynamically generated SQL queries based on user questions. Aimed to retrieve structured data like counts, specific records, or summaries from predefined schemas.",
            ),
            Tool(
                name="DocumentData",
                func=self.document_data_tool,
                description="Searches through indexed documents to find relevant information based on user queries. Handles unstructured data from various document formats like PDF, DOCX, or TXT files.",
            ),
        ]

        # Agent setup. Memory is wired into the AgentExecutor only: binding it
        # onto the LLM (the original did ``self.llm.bind(memory=...)``) would
        # forward ``memory`` as an invalid OpenAI API parameter.
        prompt_template = self.setup_prompt()
        self.agent = create_tool_calling_agent(self.llm, self.tools, prompt_template)
        self.agent_executor = AgentExecutor(
            agent=self.agent, tools=self.tools, memory=self.memory, verbose=True
        )

    def setup_prompt(self):
        """Build the chat prompt for the tool-calling agent.

        Tool-calling agents inject the scratchpad as a *list of messages*, so it
        must be a messages placeholder — a plain ``{agent_scratchpad}`` slot in a
        string template fails at runtime.
        """
        return ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are an assistant that helps with database queries and document retrieval. "
                    "Please base your responses strictly on available data and avoid assumptions. "
                    "If the question pertains to numerical data or structured queries, use the DatabaseQuery tool. "
                    "If the question relates to content within various documents, use the DocumentData tool.",
                ),
                ("human", "{input}"),
                ("placeholder", "{agent_scratchpad}"),
            ]
        )

    def database_tool(self, question):
        """Generate a SQL query for *question* and execute it against the DB."""
        sql_query = self.generate_sql_query(question)
        return self.run_query(sql_query)

    def get_schema(self, _):
        """Return the database schema description (table info).

        The ignored positional argument keeps the signature compatible with
        ``RunnablePassthrough.assign``, which calls it with the chain input.
        """
        return self.db.get_table_info()

    def generate_sql_query(self, question):
        """Ask the LLM to write a SQL query answering *question*, given the schema."""
        template_query_generation = """Generate a SQL query to answer the user's question based on the available database schema.
{schema}
Question: {question}
SQL Query:"""
        prompt_query_generation = ChatPromptTemplate.from_template(template_query_generation)
        # schema is injected lazily by the chain; only the question is supplied here.
        sql_chain = (
            RunnablePassthrough.assign(schema=self.get_schema)
            | prompt_query_generation
            | self.llm.bind(stop=["\nSQL Result:"])  # stop before the model invents results
            | StrOutputParser()
        )
        return sql_chain.invoke({'question': question})

    def run_query(self, query):
        """Execute *query* against the database.

        Returns the raw result string, or ``None`` on failure (callers treat
        this as a soft error; the exception is logged, not raised).
        """
        try:
            logging.info(f"Executing SQL query: {query}")
            result = self.db.run(query)
            logging.info(f"Query successful: {result}")
            return result
        except Exception as e:
            logging.error(f"Error executing query: {query}, Error: {str(e)}")
            return None

    def document_data_tool(self, query):
        """Search every available FAISS index for *query* and join the hits.

        Returns a newline-joined string of per-index results, or an error
        sentinel string on failure (the agent surfaces it to the user).
        """
        try:
            logging.info(f"Searching documents for query: {query}")
            embeddings = OpenAIEmbeddings(api_key=open_api_key_token)
            index_paths = self.find_index_for_document(query)
            responses = []
            for index_path in index_paths:
                # allow_dangerous_deserialization: FAISS indexes are pickled;
                # only safe because these files are produced by this app.
                vector_store = FAISS.load_local(
                    index_path, embeddings, allow_dangerous_deserialization=True
                )
                responses.append(self.query_vector_store(vector_store, query))
            logging.info(f"Document search results: {responses}")
            return "\n".join(responses)
        except Exception as e:
            logging.error(f"Error in document data tool for query: {query}, Error: {str(e)}")
            return "Error processing document query."

    def find_index_for_document(self, query):
        """Return every directory under VECTOR_DB_PATH containing a FAISS index.

        *query* is currently unused — all indexes are searched. Returns an
        empty list when VECTOR_DB_PATH is unset (the original crashed with
        ``TypeError`` from ``os.walk(None)`` in that case).
        """
        base_path = os.getenv('VECTOR_DB_PATH')
        index_paths = []
        if not base_path:
            logging.warning("VECTOR_DB_PATH is not set; no document indexes available.")
            return index_paths
        for root, dirs, _files in os.walk(base_path):
            for dir in dirs:
                if 'index.faiss' in os.listdir(os.path.join(root, dir)):
                    # Trailing separator preserved: FAISS.load_local expects a folder path.
                    index_paths.append(os.path.join(root, dir, ''))
        return index_paths

    def query_vector_store(self, vector_store, query):
        """Run a similarity search and concatenate the matched page contents."""
        docs = vector_store.similarity_search(query)
        return '\n\n'.join([doc.page_content for doc in docs])

    def answer_question(self, user_question):
        """Entry point: route *user_question* through the agent executor.

        Returns the agent's output string, or a formatted error message on
        failure (never raises to the caller).
        """
        try:
            logging.info(f"Received question: {user_question}")
            response = self.agent_executor.invoke({"input": user_question})
            output_response = response.get("output", "No valid response generated.")
            logging.info(f"Response generated: {output_response}")
            return output_response
        except Exception as e:
            logging.error(f"Error processing question: {user_question}, Error: {str(e)}")
            return f"An error occurred: {str(e)}"