# Source: Hugging Face Space file app.py (author: Mustehson),
# commit 99f3938 "Update app.py" (6.69 kB).
import os
import torch
import duckdb
import spaces
import gradio as gr
import pandas as pd
from langchain_huggingface.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langsmith import traceable
from langchain import hub
# Height (in lines) of the read-only text areas shown inside the output tabs.
TAB_LINES = 8
#----------CONNECT TO DATABASE----------
# MotherDuck access token, read from the environment.
md_token = os.getenv('MD_TOKEN')
if not md_token:
    # Without this guard an unset variable is interpolated into the connection
    # string as the literal token "None", which only fails later with an
    # opaque authentication error.
    raise RuntimeError(
        "MD_TOKEN environment variable is not set; it must contain a MotherDuck access token."
    )
# read_only=True: the model-generated SQL executed on this connection cannot
# write to the database.
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
#---------------------------------------
# Report which accelerator torch can see. `device` is not passed to the model
# loading below, which places weights itself via device_map="auto".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")
#---------------------------------------
#-------LOAD HUGGINGFACE PIPELINE-------
# One Hub checkpoint supplies both the tokenizer and the model weights.
_MODEL_ID = "motherduckdb/DuckDB-NSQL-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
# 4-bit NF4 weights with nested ("double") quantization; compute in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
# return_full_text=False: the pipeline yields only the completion, without the
# prompt echoed in front of it.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,
    max_new_tokens=1024,
)
hf = HuggingFacePipeline(pipeline=pipe)
#---------------------------------------
#-----LOAD PROMPT FROM LANGCHAIN HUB----
# Prompt template pulled from the LangChain Hub; get_prompt fills in its
# `schema` and `query_input` variables.
prompt = hub.pull("sql-agent-prompt")
#---------------------------------------
#--------------ALL UTILS----------------
# Catalog query listing every schema except the system catalogs.
_SCHEMAS_QUERY = """
SELECT DISTINCT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
"""


# Get Schemas
def get_schemas():
    """Return the names of the non-system schemas visible on ``conn``.

    ``information_schema`` and ``pg_catalog`` are filtered out by the query.

    Returns:
        List of schema-name strings.
    """
    rows = conn.execute(_SCHEMAS_QUERY).fetchall()
    return [name for (name,) in rows]
# Get Tables
def get_tables(schema_name):
    """Return the names of the tables that belong to ``schema_name``.

    Args:
        schema_name: Schema to list (the value selected in the schema
            dropdown). It is passed as a bound query parameter instead of
            being formatted into the SQL text, so quotes in the value can
            neither break nor inject into the query.

    Returns:
        List of table-name strings; empty when no table matches.
    """
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
# Update Tables
def update_tables(schema_name):
    """Rebuild the table dropdown's choices for a newly selected schema.

    Args:
        schema_name: The schema currently selected in the schema dropdown.

    Returns:
        A ``gr.update`` whose ``choices`` are the tables in ``schema_name``.
    """
    return gr.update(choices=get_tables(schema_name))
# Get Schema
def get_table_schema(table):
    """Return the CREATE TABLE statement of ``table`` with a fully qualified name.

    The DDL stored in ``duckdb_tables()`` is expected to name the table as
    ``table`` (schema ``main``) or ``schema.table``. That name is rewritten to
    ``database.schema.table`` so the schema shown to the model carries the
    full path.

    Args:
        table: Table name as selected in the table dropdown. It is passed as a
            bound parameter, so names containing quotes work and cannot
            inject SQL.

    Returns:
        The rewritten CREATE TABLE statement.

    Raises:
        ValueError: If no table with that name exists.
    """
    # NOTE(review): table names are not unique across schemas/databases; like
    # the original lookup, the first matching row wins.
    row = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?;",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table {table!r} was not found.")
    ddl_create, parent_database, schema_name = row
    full_path = f"{parent_database}.{schema_name}.{table}"
    old_path = table if schema_name == "main" else f"{schema_name}.{table}"
    # Rewrite the name only in the statement header (the text before the
    # column list), taking its last occurrence there. A global replace would
    # also mangle column names containing the table name (e.g. table "orders"
    # with column "orders_total") and could hit letters inside "CREATE TABLE".
    header, paren, columns = ddl_create.partition("(")
    before, found, after = header.rpartition(old_path)
    if found:
        header = before + full_path + after
    return header + paren + columns
# Get Prompt
def get_prompt(schema, query_input):
    """Fill the hub prompt template for a single request.

    Args:
        schema: CREATE TABLE statement describing the target table.
        query_input: The user's natural-language question.

    Returns:
        The formatted prompt produced by the template's ``format``.
    """
    template_vars = {"schema": schema, "query_input": query_input}
    return prompt.format(**template_vars)
@spaces.GPU(duration=60)
@traceable()
def generate_sql(prompt):
    """Run the language model on ``prompt`` and return its completion.

    ``spaces.GPU(duration=60)`` requests a GPU for up to 60 seconds per call,
    and ``traceable`` records each call with LangSmith.

    Returns:
        The model's generated text, stripped of surrounding whitespace.
    """
    return hf.invoke(prompt).strip()
#---------------------------------------
def _text2sql_outputs(schema, full_prompt, sql, result):
    """Map values onto the four output components wired to the Run button."""
    return {
        table_schema: schema,
        input_prompt: full_prompt,
        generated_query: sql,
        result_output: result,
    }


def _error_frame(message):
    """Wrap an error message in a one-row DataFrame for the results table."""
    return pd.DataFrame([{"error": message}])


# Generate SQL
def text2sql(table, query_input):
    """Generate SQL for a natural-language question about ``table`` and run it.

    The steps are: load the table's DDL, format it into the prompt, ask the
    model for a query, then execute that query on the read-only connection.
    Every failure is reported as an ``error`` row in the results table rather
    than raised, so the UI always receives values for all four outputs.

    Args:
        table: Selected table name; None or "" when nothing is selected.
        query_input: The user's natural-language question.

    Returns:
        Dict mapping the Schema, Prompt, SQL Query and Result components to
        their new values.
    """
    if not table:
        return _text2sql_outputs("", "", "", _error_frame("❌ Please Select Table, Schema."))
    if not query_input or not query_input.strip():
        return _text2sql_outputs("", "", "", _error_frame("❌ Please enter a text query."))

    try:
        schema = get_table_schema(table)
    except Exception as e:
        return _text2sql_outputs("", "", "", _error_frame(f"❌ Unable to load the table schema. {e}"))
    print('Schema Generated...')

    full_prompt = get_prompt(schema, query_input)
    print('Prompt Generated...')

    try:
        print(f'Generating SQL... {model.device}')
        sql = generate_sql(full_prompt)
        print('SQL Generated...')
    except Exception as e:
        return _text2sql_outputs(
            schema, full_prompt, "",
            _error_frame(f"❌ Unable to get the SQL query based on the text. {e}"),
        )

    # NOTE(review): model output is executed verbatim. The connection is
    # read-only, but DuckDB SQL may still be able to read local files or
    # remote URLs unless external access is disabled; consider restricting it.
    try:
        query_result = conn.sql(sql).df()
    except Exception as e:
        return _text2sql_outputs(
            schema, full_prompt, sql,
            _error_frame(f"❌ Unable to execute the generated SQL query. {e}"),
        )
    return _text2sql_outputs(schema, full_prompt, sql, query_result)
# Custom CSS styling
# Passed to gr.Blocks(css=...): a light page background, a rule for a
# centered logo capped at 200px wide, and blue buttons that darken on hover.
# NOTE(review): no component in this file sets elem_classes="logo" (the
# gr.Image logo is created without it), so the .logo rule appears unused.
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""
# Page layout:
#   - logo and title;
#   - a control row (schema/table pickers on the left, question box and Run
#     button on the right);
#   - one tab per output.
# NOTE(review): nesting below was reconstructed from a copy with flattened
# indentation; confirm the row/column/tab nesting against the deployed layout.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
<br>
<span style='font-size: 20px;'>Generate and Run SQL queries based on a given text for the dataset.</span>
</div>
""")
    with gr.Row():
        # Left panel: pick a schema, then one of its tables.
        with gr.Column(scale=1, variant='panel'):
            # Schema choices are queried once, when the UI is built.
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            # Starts empty; refreshed by update_tables when the schema changes.
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
        # Right panel: the natural-language question and the Run button.
        with gr.Column(scale=2):
            query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter your text query here...")
            with gr.Row():
                # Empty spacer column (scale 7 vs. 1) leaves the button a
                # narrow slot at the right.
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    generate_query_button = gr.Button("Run Query", variant="primary")
    # Output tabs; text2sql returns a dict keyed by these four components.
    with gr.Tabs():
        with gr.Tab("Result"):
            result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
        with gr.Tab("SQL Query"):
            generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
        with gr.Tab("Prompt"):
            input_prompt = gr.Textbox(lines=TAB_LINES, label="Input Prompt", value="", interactive=False)
        with gr.Tab("Schema"):
            table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)
    # Event wiring: changing the schema repopulates the table list; the Run
    # button generates and executes SQL for the selected table and question.
    schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
    generate_query_button.click(text2sql, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_output])
# Start the Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()