import re

import requests
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import (
    ListSQLDatabaseTool,
    InfoSQLDatabaseTool,
)
from sqlalchemy import (
    create_engine,
    MetaData,
    inspect,
    Table,
    select,
    distinct,
)
from sqlalchemy.schema import CreateTable
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.engine import Engine


def get_all_groq_model(api_key: str | None = None) -> list:
    """Uses the Groq API to fetch all available model ids."""
    if api_key is None:
        raise ValueError("API key is required")
    url = "https://api.groq.com/openai/v1/models"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    response = requests.get(url, headers=headers, timeout=30)
    # Raise on HTTP errors (e.g. an invalid key) instead of failing later on a missing 'data' field.
    response.raise_for_status()
    data = response.json()["data"]
    model_ids = [model["id"] for model in data]
    return model_ids
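
# Usage sketch (illustrative, not executed): the key is assumed to be stored in a
# GROQ_API_KEY environment variable.
#
#     import os
#     models = get_all_groq_model(api_key=os.environ["GROQ_API_KEY"])
#     print(models)  # list of model ids returned by the Groq API
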
def validate_api_key(api_key: str) -> bool:
    """Validates the Groq API key using the get_all_groq_model function."""
    if len(api_key) == 0:
        return False
    try:
        get_all_groq_model(api_key=api_key)
        return True
    except Exception:
        return False
def validate_uri(uri: str) -> bool:
    """Validates the SQL database URI by attempting to connect with SQLDatabase.from_uri."""
    try:
        SQLDatabase.from_uri(uri)
        return True
    except Exception:
        return False
def get_info(uri: str) -> dict[str, str]:
    """Gets the dialect name, accessible tables and table schemas using the SQLDatabase toolkit."""
    db = SQLDatabase.from_uri(uri)
    dialect = db.dialect
    # List all the tables accessible to the user.
    access_tables = ListSQLDatabaseTool(db=db).invoke("")
    # List the table schemas of all the accessible tables.
    tables_schemas = InfoSQLDatabaseTool(db=db).invoke(access_tables)
    return {"sql_dialect": dialect, "tables": access_tables, "tables_schema": tables_schemas}
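
# Usage sketch (illustrative, not executed), assuming a readable SQLite file
# named 'example.db':
#
#     if validate_uri("sqlite:///example.db"):
#         info = get_info("sqlite:///example.db")
#         print(info["sql_dialect"])  # e.g. 'sqlite'
#         print(info["tables"])       # names of the accessible tables
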
def get_sample_rows(engine: Engine, table: Table, row_count: int = 3) -> str:
    """Gets sample rows of a table using the SQLAlchemy engine."""
    # Build the SELECT statement.
    command = select(table).limit(row_count)
    # Save the column names as a tab-separated string.
    columns_str = "\t".join([col.name for col in table.columns])
    try:
        # Get the sample rows.
        with engine.connect() as connection:
            sample_rows_result = connection.execute(command)  # type: ignore
            # Shorten values in the sample rows.
            sample_rows = list(
                map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
            )
        # Save the sample rows as a tab-separated string.
        sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
    # In some dialects, a 'ProgrammingError' is raised when the table has no rows.
    except ProgrammingError:
        sample_rows_str = ""
    return (
        f"{row_count} rows from {table.name} table:\n"
        f"{columns_str}\n"
        f"{sample_rows_str}"
    )
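
# The returned string follows this layout (tab-separated; the table, columns and
# rows below are hypothetical and only illustrate the format):
#
#     3 rows from users table:
#     id    name    email
#     1     Alice   alice@example.com
#     2     Bob     bob@example.com
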
def get_unique_values(engine: Engine, table: Table) -> str:
    """Gets the unique values of each column in a table using the SQLAlchemy engine."""
    unique_values = {}
    for column in table.c:
        command = select(distinct(column))
        try:
            # Get the distinct values of this column.
            with engine.connect() as connection:
                result = connection.execute(command)  # type: ignore
                # Each row holds a single column, so unpack it to keep the bare value.
                unique_values[column.name] = [str(row[0]) for row in result]
        # In some dialects, a 'ProgrammingError' is raised when the table has no rows.
        except ProgrammingError:
            unique_values[column.name] = []
    output_str = f"Unique values of each column in {table.name}: \n"
    for column, values in unique_values.items():
        output_str += f"{column} has {len(values)} unique values: {' '.join(values[:20])}"
        if len(values) > 20:
            output_str += ", ...."
        output_str += "\n"
    return output_str
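
# The returned string follows this layout (the columns and values below are
# hypothetical; at most 20 values are listed per column, with ', ....' appended
# when there are more):
#
#     Unique values of each column in users:
#     country has 3 unique values: US UK DE
#     id has 42 unique values: 1 2 3 4 5 ..., ....
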
def get_info_sqlalchemy(uri: str) -> dict[str, str]:
    """Gets the dialect name, accessible tables and table schemas using the SQLAlchemy engine."""
    engine = create_engine(uri)
    # Get the dialect name from the inspector.
    inspector = inspect(engine)
    dialect = inspector.dialect.name
    # Reflect the metadata for tables and columns.
    m = MetaData()
    m.reflect(engine)
    tables = {}
    for table in m.tables.values():
        # Start from the CREATE TABLE statement, then append the sample rows and
        # unique values inside a /* ... */ comment block.
        tables[table.name] = str(CreateTable(table).compile(engine)).rstrip()
        tables[table.name] += "\n\n/*"
        tables[table.name] += "\n" + get_sample_rows(engine, table) + "\n"
        tables[table.name] += "\n" + get_unique_values(engine, table) + "\n"
        tables[table.name] += "*/"
    return {
        "sql_dialect": dialect,
        "tables": ", ".join(tables.keys()),
        "tables_schema": "\n\n".join(tables.values()),
    }
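
# Each entry in 'tables_schema' stacks the CREATE TABLE statement followed by a
# /* ... */ block with sample rows and unique values, a context format commonly
# fed to text-to-SQL prompts. Usage sketch (illustrative, not executed):
#
#     info = get_info_sqlalchemy("sqlite:///example.db")
#     print(info["tables_schema"])
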
def extract_code_blocks(text: str) -> list[str]:
    """Extracts the contents of fenced Markdown code blocks (e.g. from an LLM response)."""
    pattern = r"```(?:\w+)?\n(.*?)\n```"
    matches = re.findall(pattern, text, re.DOTALL)
    return matches
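
# Example (illustrative): a fenced ```sql block in an LLM response is returned
# without its fences, so the query can be passed straight to the database.
#
#     >>> extract_code_blocks("Sure!\n```sql\nSELECT 1;\n```")
#     ['SELECT 1;']
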
if __name__ == "__main__":
    import os

    from dotenv import load_dotenv

    load_dotenv()
    uri = os.getenv("POSTGRES_URI")
    if uri is None:
        raise ValueError("POSTGRES_URI is not set in the environment or .env file")
    print(get_info_sqlalchemy(uri))
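
    # Quick, dependency-free check of extract_code_blocks on a mock LLM response
    # (the response text below is made up for demonstration purposes).
    sample_response = "Here is the query:\n```sql\nSELECT * FROM users LIMIT 5;\n```"
    print(extract_code_blocks(sample_response))  # ['SELECT * FROM users LIMIT 5;']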