File size: 2,681 Bytes
16601c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from operator import itemgetter
from urllib.parse import quote

import streamlit as st
from langchain.chains import create_sql_query_chain
from langchain.memory import ChatMessageHistory
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

from prompts import create_prompts
from table_details import create_table_chain

def get_db_uri(credentials):
    """Build a SQLAlchemy PostgreSQL (psycopg2) URL from a credentials mapping.

    ``credentials`` must provide ``user``, ``password``, ``host``, ``port``
    and ``database`` keys (a missing key raises ``KeyError``).

    The user name and password are percent-encoded, so characters such as
    ``@``, ``:``, ``/``, ``%`` or spaces cannot be mistaken for URL
    delimiters. Plain alphanumeric values are left unchanged.
    """
    # safe="" escapes "/" and ":" as well; quote (not quote_plus) encodes a
    # space as %20 rather than "+", so it round-trips when the URL is parsed.
    user = quote(str(credentials['user']), safe="")
    password = quote(str(credentials['password']), safe="")
    return (
        f"postgresql+psycopg2://{user}:{password}"
        f"@{credentials['host']}:{credentials['port']}/{credentials['database']}"
    )

@st.cache_resource
def _build_chain(db_uri, api_key):
    """Build (and cache) the text-to-SQL chain for one database / API-key pair.

    Failures propagate as exceptions instead of being turned into a return
    value. Streamlit does not cache a call that raises, so a failed
    connection is retried on the next call rather than being stored as
    ``None``.

    The argument is named ``db_uri`` (no leading underscore) on purpose.
    ``st.cache_resource`` leaves underscore-prefixed arguments out of the
    cache key, which would make every database share one cached chain per
    API key.
    """
    db = SQLDatabase.from_uri(db_uri)
    llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo", api_key=api_key)

    # Table-selection chain and prompts come from the project's
    # table_details / prompts modules.
    table_chain = create_table_chain(api_key)
    final_prompt, answer_prompt = create_prompts(api_key)

    generate_query = create_sql_query_chain(llm, db, final_prompt)
    # NOTE(review): the LLM-generated SQL is executed as-is; prefer
    # credentials for a read-only database role.
    execute_query = QuerySQLDataBaseTool(db=db)
    rephrase_answer = answer_prompt | llm | StrOutputParser()

    # table_chain -> table_names_to_use, generate_query -> query,
    # execute the query -> result, then rephrase into the final answer text.
    return (
        RunnablePassthrough.assign(table_names_to_use=table_chain)
        | RunnablePassthrough.assign(query=generate_query).assign(
            result=itemgetter("query") | execute_query
        )
        | rephrase_answer
    )

def get_chain(_db_uri, api_key):
    """Return the cached LangChain pipeline for the given database URI and key.

    Args:
        _db_uri: SQLAlchemy database URL (name kept for backward compatibility;
            it IS part of the cache key of the underlying builder).
        api_key: OpenAI API key used for every LLM component.

    Returns:
        The runnable chain, or ``None`` after showing ``st.error`` if the
        chain could not be built. Failures are not cached, so a later call
        (e.g. with corrected credentials) retries.
    """
    try:
        return _build_chain(_db_uri, api_key)
    except Exception as e:
        # Report in the UI here, outside the cached function, so the error
        # element is not replayed from the cache.
        st.error(f"Error creating chain: {str(e)}")
        return None

# Keep ``get_chain.clear()`` working for callers that reset the cache.
get_chain.clear = _build_chain.clear

def create_history(messages):
    """Rebuild a ``ChatMessageHistory`` from a list of chat-message dicts.

    Each item must have ``"role"`` and ``"content"`` keys. A role equal to
    ``"user"`` becomes a user message; any other role is stored as an AI
    message.
    """
    chat_history = ChatMessageHistory()
    for entry in messages:
        from_user = entry["role"] == "user"
        append = (
            chat_history.add_user_message
            if from_user
            else chat_history.add_ai_message
        )
        append(entry["content"])
    return chat_history

def invoke_chain(question, messages, db_credentials, api_key, *, top_k=100):
    """Answer ``question`` against the configured database via the LLM chain.

    Args:
        question: The user's natural-language question.
        messages: Prior chat turns as dicts with ``"role"`` and ``"content"``
            keys; converted to LangChain messages and passed as ``messages``.
        db_credentials: Mapping with user/password/host/port/database keys.
        api_key: OpenAI API key.
        top_k: Value supplied to the chain input as ``"top_k"`` (default 100).
            Presumably the prompt's row limit for generated SQL — confirm
            against the prompt template.

    Returns:
        The chain's answer, or a user-facing error string if the chain could
        not be built or invocation raised. The history is rebuilt from
        ``messages`` on every call; this function does not record the new
        turn, so the caller must persist it.
    """
    try:
        db_uri = get_db_uri(db_credentials)
        chain = get_chain(db_uri, api_key)
        if chain is None:
            return "Sorry, I couldn't connect to the database. Please check your credentials."

        history = create_history(messages)
        return chain.invoke({
            "question": question,
            "top_k": top_k,
            "messages": history.messages,
        })
    except Exception as e:
        # UI boundary: surface any failure to the user as text instead of
        # crashing the Streamlit script.
        return f"An error occurred: {str(e)}"