# Advance-NL-to-SQL / table_details.py
# Uploaded by sango07 ("Upload 6 files"), commit 16601c8 (verified).
import pandas as pd
import streamlit as st
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from typing import List
# Extraction schema: create_table_chain passes this model to
# create_extraction_chain_pydantic, and get_tables reads `.name` off each
# extracted instance. BaseModel/Field come from langchain_core.pydantic_v1.
# NOTE: the class docstring and the Field description are not mere docs --
# pydantic puts them in the model's JSON schema, which presumably becomes the
# tool definition shown to the LLM. Editing them changes the prompt.
class Table(BaseModel):
    """Table in SQL database."""

    # Name of one table the model considers relevant to the user's question.
    name: str = Field(description="Name of table in SQL database.")
def get_tables(tables: List[Table]) -> List[str]:
    """Flatten extracted ``Table`` objects into their plain names.

    Args:
        tables: Extraction results; each element only needs a ``name``
            attribute.

    Returns:
        The ``name`` of every element, in the same order as ``tables``.
    """
    names: List[str] = []
    for table_obj in tables:
        names.append(table_obj.name)
    return names
@st.cache_data
def _load_table_details(path: str) -> str:
    """Read the table-description workbook and render it as prompt text.

    Cached with ``st.cache_data`` per distinct ``path``. Errors are allowed
    to propagate on purpose: Streamlit does not cache a call that raises, so
    a missing or broken workbook is retried on the next call instead of an
    empty result being cached for the lifetime of the app.

    Args:
        path: Excel file with ``Table`` and ``Description`` columns.

    Returns:
        For each row, a "Table Name:..." line, a "Table Description:..."
        line and a blank separator line, concatenated in row order.

    Raises:
        Whatever ``pd.read_excel`` raises (e.g. ``FileNotFoundError``), or
        ``KeyError`` if either required column is missing.
    """
    table_description = pd.read_excel(path)
    # Walk the two columns directly: no per-row Series is built (unlike
    # iterrows), each column keeps its own dtype, and join() avoids
    # quadratic string concatenation.
    return "".join(
        f"Table Name:{table}\nTable Description:{description}\n\n"
        for table, description in zip(
            table_description["Table"], table_description["Description"]
        )
    )


def get_table_details(path: str = "database_table_descriptions.xlsx") -> str:
    """Return the table descriptions as prompt text, or "" if they can't be read.

    Args:
        path: Excel workbook to read; defaults to the file name used before
            this parameter existed.

    Returns:
        The rendered table descriptions, or an empty string after showing an
        ``st.error`` message when reading/rendering fails.
    """
    try:
        return _load_table_details(path)
    except Exception as e:  # UI boundary: report in the app instead of crashing
        st.error(f"Error reading table descriptions: {str(e)}")
        return ""


# Keep the ``.clear()`` hook callers had when get_table_details itself was
# decorated with ``st.cache_data``.
get_table_details.clear = _load_table_details.clear
def create_table_chain(api_key, model="gpt-3.5-turbo", temperature=0.7):
    """Build a runnable that maps ``{"question": ...}`` to relevant table names.

    The question is sent to an OpenAI chat model whose system prompt lists
    every table from ``get_table_details()``. ``Table`` objects are extracted
    from the reply and flattened to strings by ``get_tables``.

    Args:
        api_key: OpenAI API key forwarded to ``ChatOpenAI``.
        model: Chat model name (default preserves the original behavior).
        temperature: Sampling temperature (default preserves the original
            behavior).

    Returns:
        A LangChain runnable. Invoking it with a dict containing
        ``"question"`` yields a list of table-name strings.
    """
    # If reading the descriptions failed, this is "" (an st.error was already
    # shown) and the prompt below simply lists no tables.
    table_details = get_table_details()
    # create_extraction_chain_pydantic treats system_message as an f-string
    # prompt *template*; literal braces in the descriptions (e.g. JSON
    # snippets) would otherwise be parsed as template variables.
    escaped_details = table_details.replace("{", "{{").replace("}", "}}")
    table_details_prompt = (
        "Return the names of ALL the SQL tables that MIGHT be relevant "
        "to the user question. "
        "The tables are:\n"
        f"{escaped_details}\n"
        "Remember to include ALL POTENTIALLY RELEVANT tables, "
        "even if you're not sure that they're needed."
    )
    # NOTE(review): table selection is an extraction task; temperature=0 would
    # make it deterministic. The 0.7 default is kept to preserve behavior.
    llm = ChatOpenAI(temperature=temperature, model=model, api_key=api_key)
    return (
        {"input": itemgetter("question")}
        | create_extraction_chain_pydantic(
            Table, llm, system_message=table_details_prompt
        )
        | get_tables
    )