File size: 5,371 Bytes
66b6353
 
 
ddb520c
 
 
 
 
 
 
 
 
 
 
6cfa9f4
66b6353
 
ddb520c
66b6353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddb520c
66b6353
 
 
 
 
 
 
 
 
ddb520c
66b6353
 
 
 
 
 
 
ddb520c
66b6353
 
 
 
 
 
 
 
ddb520c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cfa9f4
 
 
 
 
66b6353
ddb520c
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import requests
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool
from sqlalchemy import (
    create_engine,
    MetaData,
    inspect,
    Table,
    select,
    distinct
)
from sqlalchemy.schema import CreateTable
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.engine import Engine
import re

def get_all_groq_model(api_key:str=None) -> list:
    """Uses Groq API to fetch all the available models."""
    if api_key is None:
        raise ValueError("API key is required")
    url = "https://api.groq.com/openai/v1/models"

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    response = requests.get(url, headers=headers)

    data = response.json()['data']
    model_ids = [model['id'] for model in data]

    return model_ids

def validate_api_key(api_key:str) -> bool:
    """Check a Groq API key by trying to list the available models with it.

    Args:
        api_key: The key to check.

    Returns:
        ``True`` if ``get_all_groq_model`` succeeds with the key; ``False`` if
        the key is ``None``/empty or the call raises any exception.
    """
    # A missing or empty key is invalid; answer without touching the network.
    # (Truthiness also covers None, which len() would have rejected with a
    # TypeError.)
    if not api_key:
        return False
    try:
        get_all_groq_model(api_key=api_key)
    except Exception:
        # Any failure (bad key, network error, unexpected payload) means the
        # key could not be validated.
        return False
    return True

def validate_uri(uri:str) -> bool:
    """Report whether ``SQLDatabase.from_uri`` accepts the given URI.

    Returns ``True`` when the call completes without raising and ``False``
    when it raises any exception.
    """
    is_valid = True
    try:
        SQLDatabase.from_uri(uri)
    except Exception:
        is_valid = False
    return is_valid

def get_info(uri:str) -> dict[str, str] | None:
    """Collect dialect, table listing and table schemas via the SQLDatabase tools.

    The returned dict has the keys ``sql_dialect``, ``tables`` and
    ``tables_schema``.
    """
    database = SQLDatabase.from_uri(uri)
    dialect_name = database.dialect

    # The listing tool's output is fed straight into the info tool, which
    # produces the schema text for those tables.
    table_listing = ListSQLDatabaseTool(db=database).invoke("")
    schema_text = InfoSQLDatabaseTool(db=database).invoke(table_listing)

    return dict(
        sql_dialect=dialect_name,
        tables=table_listing,
        tables_schema=schema_text,
    )

def get_sample_rows(engine:Engine, table:Table, row_count: int = 3) -> str:
    """Render up to ``row_count`` rows of ``table`` as tab-separated text.

    The result is a heading line, a line of column names, and one line per
    fetched row. Each value is converted with ``str`` and cut to 100
    characters.
    """
    header = "\t".join(column.name for column in table.columns)
    query = select(table).limit(row_count)

    try:
        with engine.connect() as conn:
            # Stringify and truncate every cell while the connection is open.
            truncated = [
                [str(value)[:100] for value in row]
                for row in conn.execute(query)  # type: ignore
            ]
        body = "\n".join("\t".join(cells) for cells in truncated)
    except ProgrammingError:
        # Per the original author's note, some dialects raise
        # ProgrammingError when the table has no rows; show no rows then.
        body = ""

    return f"{row_count} rows from {table.name} table:\n{header}\n{body}"

def get_unique_values(engine:Engine, table:Table) -> str:
    """Summarise the distinct values of every column in ``table`` as text.

    For each column a ``SELECT DISTINCT`` is executed on a fresh connection.
    The report has a heading line followed by one line per column, giving the
    number of distinct values and the first 20 of them (followed by
    ``", ...."`` when there are more).

    Args:
        engine: SQLAlchemy engine used to open connections.
        table: The table whose columns are inspected.

    Returns:
        The multi-line report, terminated by a newline. Columns whose query
        raises ``ProgrammingError`` are left out of the report.
    """
    unique_values = {}
    for column in table.c:
        command = select(distinct(column))

        try:
            with engine.connect() as connection:
                result = connection.execute(command)  # type: ignore
                # Every result row holds a single element; take it out of the
                # row so the text shows the value itself rather than the
                # tuple-style rendering of the whole row.
                unique_values[column.name] = [str(row[0]) for row in result]
        except ProgrammingError:
            # Per the original author's note, some dialects raise
            # ProgrammingError when the table has no rows; skip the column.
            continue

    lines = [f"Unique values of each column in {table.name}: "]
    for name, values in unique_values.items():
        # Join outside the f-string: reusing the enclosing quote character
        # inside a replacement field is a syntax error before Python 3.12.
        preview = " ".join(values[:20])
        line = f"{name} has {len(values)} unique values: {preview}"
        if len(values) > 20:
            line += ", ...."
        lines.append(line)

    return "\n".join(lines) + "\n"

def get_info_sqlalchemy(uri:str) -> dict[str, str] | None:
    """Collect dialect, table names and annotated table schemas via SQLAlchemy.

    Every reflected table is rendered as its ``CREATE TABLE`` statement
    followed by a ``/* ... */`` comment holding the output of
    ``get_sample_rows`` and ``get_unique_values`` for that table.

    Args:
        uri: Database URI passed to ``create_engine``.

    Returns:
        A dict with ``sql_dialect`` (the dialect name), ``tables`` (the table
        names joined by ``", "``) and ``tables_schema`` (the per-table
        descriptions joined by blank lines).
    """
    engine = create_engine(uri)
    try:
        # Get dialect name using inspector
        dialect = inspect(engine).dialect.name
        # Metadata for tables and columns
        m = MetaData()
        m.reflect(engine)

        tables = {}
        for table in m.tables.values():
            ddl = str(CreateTable(table).compile(engine)).rstrip()
            tables[table.name] = (
                f"{ddl}\n\n/*"
                f"\n{get_sample_rows(engine, table)}\n"
                f"\n{get_unique_values(engine, table)}\n"
                "*/"
            )
    finally:
        # The engine is local to this call; release its connection pool so
        # repeated calls do not accumulate open connections.
        engine.dispose()

    return {'sql_dialect': dialect, 'tables': ", ".join(tables.keys()), 'tables_schema': "\n\n".join(tables.values())}

def extract_code_blocks(text):
    """Return the contents of every triple-backtick fenced block in ``text``.

    A fence may carry an optional word-character language tag on its opening
    line. Only the text between the opening line and the closing fence is
    returned, in order of appearance; the list is empty if nothing matches.
    """
    fence = re.compile(r"```(?:\w+)?\n(.*?)\n```", re.DOTALL)
    return [found.group(1) for found in fence.finditer(text)]

if __name__ == "__main__":
    # Manual smoke test: print the schema description of the database named
    # by the POSTGRES_URI environment variable (loaded from .env if present).
    from dotenv import load_dotenv
    import os
    load_dotenv()

    uri = os.getenv("POSTGRES_URI")
    # Fail with a clear message instead of handing None to create_engine.
    if not uri:
        raise SystemExit("POSTGRES_URI is not set; define it in the environment or a .env file")
    print(get_info_sqlalchemy(uri))