Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,941 Bytes
27e0148 140f5d3 ad4860f 140f5d3 27e0148 ad4860f 140f5d3 27e0148 8e28628 10c1114 27e0148 140f5d3 27e0148 f05ca9d 27e0148 49e03fb 27e0148 49e03fb 27e0148 ad4860f 27e0148 8e28628 10c1114 8e28628 ad4860f 27e0148 140f5d3 27e0148 ad4860f a127a18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import os
import duckdb
import gradio as gr
from dotenv import load_dotenv
from httpx import Client
from huggingface_hub import HfApi
#from llama_cpp import Llama
import pandas as pd
#from transformers import pipeline
# Pull variables from a local .env file (if present) into the environment
# before reading the token.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
# Explicit check rather than `assert`: assertions are removed under
# `python -O`, which would let the app start and send "Bearer None".
if HF_TOKEN is None:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")

# Endpoint used to look up the Parquet files of a dataset.
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
# Inference endpoint that turns the prompt into SQL.
API_URL = "https://m82etjwvhoptr3t5.us-east-1.aws.endpoints.huggingface.cloud"

# Default headers attached to every request made through `client`.
headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json",
}
client = Client(headers=headers)
api = HfApi(token=HF_TOKEN)
# First approach: Use llama.cpp
#llama = Llama(model_path="DuckDB-NSQL-7B-v0.1-q8_0.gguf", n_ctx=2048)
#def query_local_model(text):
# pred = llama(text, temperature=0.1, max_tokens=500)
# return pred["choices"][0]["text"]
# Second approach: Use transformers -> Took too much time
#pipe = pipeline("text-generation", model="motherduckdb/DuckDB-NSQL-7B-v0.1")
#def query_local_model_transformers(text):
# pred = pipe(text, max_length=1000)
# return pred[0]["generated_text"]
def get_first_parquet(dataset: str):
    """Return the first entry of the ``parquet_files`` list for *dataset*.

    Queries ``{BASE_DATASETS_SERVER_URL}/parquet`` for the given dataset and
    returns the first element of the ``parquet_files`` array in the JSON
    response (the caller reads its ``"url"`` key).

    Raises:
        httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        KeyError: if the response has no ``parquet_files`` key.
        ValueError: if the dataset exposes no Parquet files at all.
    """
    # Pass the dataset name via `params` so it is URL-encoded correctly
    # instead of being spliced raw into the query string.
    resp = client.get(
        f"{BASE_DATASETS_SERVER_URL}/parquet",
        params={"dataset": dataset},
    )
    # Surface HTTP errors directly rather than failing later on a missing key.
    resp.raise_for_status()
    parquet_files = resp.json()["parquet_files"]
    if not parquet_files:
        raise ValueError(f"No parquet files found for dataset {dataset!r}")
    return parquet_files[0]
def query_remote_model(text):
    """Send *text* to the remote inference endpoint and return the generated text.

    The prompt is posted as JSON (``{"inputs": text, "parameters": {}}``) to
    ``API_URL``; the function returns the ``generated_text`` field of the
    first element of the JSON response.

    Raises:
        httpx.HTTPStatusError: if the endpoint answers with a 4xx/5xx status.
        KeyError / IndexError / TypeError: if the response body does not have
            the expected ``[{"generated_text": ...}]`` shape.
    """
    payload = {
        "inputs": text,
        "parameters": {},
    }
    response = client.post(API_URL, headers=headers, json=payload)
    # Without this, an error payload (a dict) would fail below with an opaque
    # `KeyError: 0` instead of reporting the actual HTTP failure.
    response.raise_for_status()
    pred = response.json()
    return pred[0]["generated_text"]
def _error_result(message):
    """Build the callback output for a failure: empty SQL box, one-row error table."""
    return {
        query_output: "",
        df: pd.DataFrame([{"error": message}]),
    }


def text2sql(dataset_name, query_input):
    """Generate a DuckDB SQL query for *query_input* and run it on the dataset.

    Steps:
      1. Look up the first Parquet file of ``dataset_name``.
      2. Load one row of it into a DuckDB table to obtain its ``CREATE TABLE``
         statement (the schema shown to the model).
      3. Ask the remote model for a SQL query answering ``query_input``.
      4. Execute the generated SQL and collect the result as a DataFrame.

    Returns:
        A dict keyed by the ``query_output`` and ``df`` Gradio components,
        holding the generated SQL and the result (or a one-row ``error``
        DataFrame when a step fails). The function does not raise for
        failures of the steps above; they are reported in the DataFrame.
    """
    print(f"start text2sql for {dataset_name}")
    try:
        first_parquet = get_first_parquet(dataset_name)
        # Kept inside the try so a malformed entry is reported, not raised.
        first_parquet_url = first_parquet["url"]
    except Exception as error:
        return _error_result(f"❌ Could not get dataset schema. {error=}")
    print(first_parquet_url)

    con = duckdb.connect()
    # Single try/finally so the connection is closed on every exit path,
    # including the early error returns below.
    try:
        try:
            con.execute("INSTALL 'httpfs'; LOAD httpfs;")
            # Double any single quote so the URL cannot terminate the SQL
            # string literal it is embedded in.
            escaped_url = first_parquet_url.replace("'", "''")
            # could get from Parquet instead?
            con.execute(f"CREATE TABLE data as SELECT * FROM '{escaped_url}' LIMIT 1;")
            result = con.sql("SELECT sql FROM duckdb_tables() where table_name ='data';").df()
            ddl_create = result.iloc[0, 0]
        except Exception as error:
            return _error_result(f"❌ Could not get dataset schema. {error=}")

        text = f"""### Instruction:
Your task is to generate valid duckdb SQL to answer the following question.
### Input:
Here is the database schema that the SQL query will run on:
{ddl_create}
### Question:
{query_input}
### Response (use duckdb shorthand if possible) replace table name with {first_parquet_url} in the generated sql query:
"""
        try:
            sql_output = query_remote_model(text)
        except Exception as error:
            return _error_result(f"❌ Unable to get the SQL query based on the text. {error=}")

        # NOTE(review): the model-generated SQL is executed as-is. Consider
        # restricting it (e.g. to read-only SELECT statements) before running.
        try:
            query_result = con.sql(sql_output).df()
        except Exception as error:
            query_result = pd.DataFrame([{"error": f"❌ Could not execute SQL query {error=}"}])
    finally:
        con.close()

    return {
        query_output: sql_output,
        df: query_result,
    }
# Build the UI: a header, two inputs, a trigger button and two outputs.
with gr.Blocks() as demo:
    # Static description shown at the top of the page.
    gr.Markdown(value="# Generate SQL queries based on a given text for your dataset")
    gr.Markdown(value="This space showcase how to generate a SQL query from a text and get the result.")
    gr.Markdown(value="Tech stack: duckdb and DuckDB-NSQL-7B model")

    # Inputs, pre-filled with an example dataset and question.
    dataset_name = gr.Textbox(
        value="sksayril/medicine-info",
        label="Dataset Name",
    )
    query_input = gr.Textbox(
        value="How many rows there are?",
        label="Ask something about your data",
    )
    btn = gr.Button(value="Generate SQL")

    # Outputs: the generated SQL (read-only) and the query result table.
    query_output = gr.Textbox(label="Output SQL", interactive=False)
    df = gr.DataFrame(datatype="markdown")

    # text2sql returns a dict keyed by the two output components.
    btn.click(
        fn=text2sql,
        inputs=[dataset_name, query_input],
        outputs=[query_output, df],
    )

demo.launch(debug=True)
|