File size: 1,534 Bytes
16601c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import pandas as pd
import streamlit as st
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from typing import List

class Table(BaseModel):
    """Table in SQL database."""
    # Schema handed to create_extraction_chain_pydantic in create_table_chain;
    # each extracted instance carries one table name, which get_tables reads.
    # NOTE(review): the class docstring and the Field description below are
    # presumably forwarded to the model as part of the tool schema, so treat
    # them as prompt text rather than ordinary documentation -- confirm before
    # rewording either.
    name: str = Field(description="Name of table in SQL database.")

def get_tables(tables: List[Table]) -> List[str]:
    """Return the ``name`` of every extracted table, preserving input order."""
    names = []
    for entry in tables:
        names.append(entry.name)
    return names

@st.cache_data
def get_table_details() -> str:
    """Render the table catalogue text that gets embedded in the LLM prompt.

    Reads ``database_table_descriptions.xlsx`` (a relative path, so it is
    resolved against the current working directory) and emits one entry per
    spreadsheet row, built from the ``Table`` and ``Description`` columns.

    Returns:
        All entries concatenated in row order, each consisting of a
        "Table Name:" line, a "Table Description:" line and a blank line.
        An empty string is returned if the workbook cannot be read or an
        expected column is missing; the error is shown via ``st.error``.
    """
    # NOTE(review): because of st.cache_data, the empty-string fallback is
    # presumably cached like any other return value, so a transient read
    # failure may persist until the cache is cleared -- confirm this is
    # acceptable.
    try:
        table_description = pd.read_excel("database_table_descriptions.xlsx")
        # Build the text with a single join instead of growing a string with
        # ``+=`` on every row. The generator is consumed inside the ``try`` so
        # a missing column is still reported through the handler below.
        return "".join(
            f"Table Name:{row['Table']}\nTable Description:{row['Description']}\n\n"
            for _, row in table_description.iterrows()
        )
    except Exception as e:
        # Deliberate best-effort boundary: any failure (missing file, missing
        # Excel engine, missing column) is surfaced in the UI and the caller
        # continues with an empty catalogue instead of crashing the app.
        st.error(f"Error reading table descriptions: {str(e)}")
        return ""

def create_table_chain(api_key):
    """Build a chain that maps a user question to candidate SQL table names.

    The returned runnable expects a mapping with a ``"question"`` key. It
    forwards that value as ``"input"`` to an extraction chain built over the
    ``Table`` schema, then reduces the extracted objects to their names with
    ``get_tables``.

    Args:
        api_key: API key passed through to ``ChatOpenAI``.

    Returns:
        The composed runnable; its final step is ``get_tables``, which yields
        a list of table-name strings.
    """
    # The system message is turned into a prompt template by
    # create_extraction_chain_pydantic, so a literal "{" or "}" coming from
    # the spreadsheet text would be parsed as a template variable. Double the
    # braces so they are rendered verbatim.
    table_details = get_table_details().replace("{", "{{").replace("}", "}}")
    table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
    The tables are:

    {table_details}

    Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

    # NOTE(review): temperature=0.7 makes table selection non-deterministic
    # between runs; a lower value may suit an extraction step better. Left
    # unchanged here -- confirm the intended setting.
    llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo", api_key=api_key)
    return (
        {"input": itemgetter("question")}
        | create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt)
        | get_tables
    )