import os
import re
import duckdb
import gradio as gr
from dotenv import load_dotenv
from httpx import Client
from huggingface_hub import HfApi
#from llama_cpp import Llama
import pandas as pd
#from transformers import pipeline

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"


BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
API_URL = "https://m82etjwvhoptr3t5.us-east-1.aws.endpoints.huggingface.cloud"
headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json",
}

client = Client(headers=headers)
api = HfApi(token=HF_TOKEN)

# First approach: Use llama.cpp
#llama = Llama(model_path="DuckDB-NSQL-7B-v0.1-q8_0.gguf", n_ctx=2048)
#def query_local_model(text):
#    pred = llama(text, temperature=0.1, max_tokens=500)
#    return pred["choices"][0]["text"]


# Second approach: Use transformers -> Took too much time
#pipe = pipeline("text-generation", model="motherduckdb/DuckDB-NSQL-7B-v0.1")
#def query_local_model_transformers(text):
#    pred = pipe(text, max_length=1000)
#    return pred[0]["generated_text"]
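
# Third approach (used below): query the model remotely through a dedicated
# Inference Endpoint (API_URL above) instead of loading it locally.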


def get_first_parquet(dataset: str):
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/parquet?dataset={dataset}")
    # Surface HTTP errors early; the caller catches and reports them.
    resp.raise_for_status()
    return resp.json()["parquet_files"][0]
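
# Illustrative (assumed) shape of the /parquet response that get_first_parquet()
# indexes into; the exact fields may differ:
#   {"parquet_files": [{"dataset": "...", "config": "...", "split": "...",
#                       "url": "https://huggingface.co/...parquet", ...}, ...]}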


def query_remote_model(text):
    payload = {
        "inputs": text,
        "parameters": {}
    }
    response = client.post(API_URL, headers=headers, json=payload)
    pred = response.json()
    return pred[0]["generated_text"]
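
# The endpoint is assumed to follow the text-generation-inference response format,
# i.e. a JSON list like [{"generated_text": "SELECT ..."}], hence taking the first
# element's "generated_text" above. Generation settings (e.g. "max_new_tokens",
# "temperature") could be passed through the "parameters" dict if needed.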


def text2sql(dataset_name, query_input):
    print(f"start text2sql for {dataset_name}")
    try:
        first_parquet = get_first_parquet(dataset_name)
    except Exception as error:
        return {
            schema_output: "",
            prompt_output: "",
            query_output: "",
            df: pd.DataFrame([{"error": f"❌ Could not get dataset schema. {error=}"}]),
        }

    first_parquet_url = first_parquet["url"]
    print(f"getting schema from {first_parquet_url}")
    con = duckdb.connect()
    con.execute("INSTALL 'httpfs'; LOAD httpfs;")
    # Create a one-row table from the remote Parquet file just to capture its schema;
    # the DDL could alternatively be read from the Parquet metadata directly.
    con.execute(f"CREATE TABLE data AS SELECT * FROM '{first_parquet_url}' LIMIT 1;")
    # duckdb_tables() exposes each table's CREATE TABLE statement in its "sql" column.
    result = con.sql("SELECT sql FROM duckdb_tables() WHERE table_name = 'data';").df()
    ddl_create = result.iloc[0, 0]
    
    text = f"""### Instruction:
    Your task is to generate valid duckdb SQL to answer the following question. The SQL output should replace all table names with parquet file {first_parquet_url}

    ### Input:
    Here is the database schema that the SQL query will run on:
    {ddl_create}
    
    ### Question:
    {query_input}

    ### Response (use duckdb shorthand if possible) replace all table names with {first_parquet_url} in the generated sql query:
    """
    try:
        sql_output = query_remote_model(text)
    except Exception as error:
        return {
            schema_output: ddl_create,
            prompt_output: text,
            query_output: "",
            df: pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {error=}"}]),
        }

    # The prompt asks the model to substitute the table name itself, but it does not
    # do so reliably, so patch the query here; a word-boundary match avoids touching
    # identifiers or URLs that merely contain "data" as a substring.
    sql_output = re.sub(r"\bdata\b", f"'{first_parquet_url}'", sql_output)
    try:
        query_result = con.sql(sql_output).df()
    except Exception as error:
        query_result = pd.DataFrame([{"error": f"❌ Could not execute SQL query {error=}"}])
    finally:
        con.close()
    return {
        schema_output: ddl_create,
        prompt_output: text,
        query_output: sql_output,
        df: query_result,
    }


with gr.Blocks() as demo:
    gr.Markdown("# Generate SQL queries based on a given text for your dataset")
    gr.Markdown("This space showcase how to generate a SQL query from a text and get the result.")
    gr.Markdown("Tech stack: duckdb and DuckDB-NSQL-7B model")
    dataset_name = gr.Textbox("jamescalam/world-cities-geo", label="Dataset Name")
    query_input = gr.Textbox("Which cities are part of Albania?", label="Ask something about your data")
    examples = [
        ["Cities from Albania"],
        ["The continent with the most countries"],
        ["Cities that start with 'A'"],
        ["Cities by region"],
    ]
    gr.Examples(examples=examples, inputs=query_input)
    btn = gr.Button("Generate SQL")
    schema_output = gr.Textbox(label="Parquet Schema as CREATE DDL", interactive=False)
    prompt_output = gr.Textbox(label="Generated prompt", interactive=False)
    query_output = gr.Textbox(label="Output SQL", interactive=False)
    df = gr.DataFrame(datatype="markdown")
    btn.click(text2sql, inputs=[dataset_name, query_input], outputs=[schema_output, prompt_output, query_output, df])
demo.launch(debug=True)