Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,076 Bytes
27e0148 ad4860f d61f780 27e0148 ad4860f d61f780 27e0148 d61f780 ad4860f 27e0148 d61f780 638b67a d61f780 54b1e02 d61f780 27e0148 ad4860f 27e0148 f05ca9d 27e0148 e2721aa 27e0148 ad4860f 27e0148 e2721aa 27e0148 d61f780 ad4860f 27e0148 ad4860f a127a18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os
import duckdb
import gradio as gr
from dotenv import load_dotenv
from httpx import Client
from huggingface_hub import HfApi
from huggingface_hub.utils import logging
from llama_cpp import Llama
import pandas as pd
from transformers import pipeline
# Load environment variables (HF_TOKEN in particular) before reading them.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
# Explicit check rather than `assert`: assertions are stripped under
# `python -O`, which would let a missing token through and send the literal
# header "Bearer None" with every request.
if HF_TOKEN is None:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")

# Base URL used by get_first_parquet() to look up a dataset's Parquet files.
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
# Remote text-generation endpoint used by query_remote_model().
API_URL = "https://m82etjwvhoptr3t5.us-east-1.aws.endpoints.huggingface.cloud"
headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json",
}
logger = logging.get_logger(__name__)
# Shared HTTP client; every request carries the auth headers above.
client = Client(headers=headers)
api = HfApi(token=HF_TOKEN)

print("About to load DuckDB-NSQL-7B model")
# The llama.cpp (GGUF) variant of the model is not loaded at startup; the
# transformers pipeline below is what text2sql() uses. Settings that were
# used for the llama.cpp variant:
#   Llama(model_path="DuckDB-NSQL-7B-v0.1-q8_0.gguf", n_ctx=2048)
pipe = pipeline("text-generation", model="motherduckdb/DuckDB-NSQL-7B-v0.1")
print("DuckDB-NSQL-7B model has been loaded")
def get_first_parquet(dataset: str):
    """Return the first Parquet file entry listed for *dataset*.

    Queries the ``/parquet`` route of ``BASE_DATASETS_SERVER_URL`` and returns
    the first element of the ``parquet_files`` list in the JSON response
    (the caller reads its ``"url"`` key).

    Raises:
        httpx.HTTPStatusError: the server answered with a 4xx/5xx status.
        KeyError, IndexError: the response has no ``parquet_files`` entry or
            the list is empty.
    """
    # Pass the dataset name via `params` so it is URL-encoded properly instead
    # of being spliced raw into the query string.
    resp = client.get(
        f"{BASE_DATASETS_SERVER_URL}/parquet",
        params={"dataset": dataset},
    )
    # Surface HTTP failures directly rather than as a KeyError on the error body.
    resp.raise_for_status()
    return resp.json()["parquet_files"][0]
def query_remote_model(text):
    """Send *text* to the remote endpoint at ``API_URL`` and return its output.

    The prompt is posted as ``{"inputs": text, "parameters": {}}`` and the
    ``generated_text`` field of the first prediction in the JSON response is
    returned.

    Raises:
        httpx.HTTPStatusError: the endpoint answered with a 4xx/5xx status.
    """
    payload = {
        "inputs": text,
        "parameters": {},
    }
    response = client.post(API_URL, headers=headers, json=payload)
    # Fail with the real HTTP error instead of an opaque KeyError/TypeError
    # when the endpoint returns an error body rather than predictions.
    response.raise_for_status()
    pred = response.json()
    return pred[0]["generated_text"]
def query_local_model_transformers(text):
    """Generate a completion for *text* with the module-level ``pipe``.

    Returns only the newly generated text of the first candidate. The caller
    executes the result as SQL, so the prompt must not be echoed back.
    """
    # return_full_text=False: by default the text-generation pipeline returns
    # prompt + completion, which is not executable SQL.
    pred = pipe(text, max_length=1000, return_full_text=False)
    print(pred)
    return pred[0]["generated_text"]
def _get_llama():
    """Return the llama.cpp model, loading it on first use."""
    global llama
    try:
        # Reuse a module-level `llama` if one has already been created.
        return llama
    except NameError:
        # No `llama` global exists yet: build it lazily so the (large) GGUF
        # model is only loaded when this backend is actually used.
        from llama_cpp import Llama

        llama = Llama(
            model_path="DuckDB-NSQL-7B-v0.1-q8_0.gguf",
            n_ctx=2048,
        )
        return llama


def query_local_model(text):
    """Generate a completion for *text* with the local llama.cpp model.

    Returns the ``text`` field of the first entry in the result's
    ``choices`` list.
    """
    pred = _get_llama()(text, temperature=0.1, max_tokens=500)
    return pred["choices"][0]["text"]
def text2sql(dataset_name, query_input):
    """Turn a natural-language question about a dataset into SQL and run it.

    Steps:
      1. Look up the dataset's first Parquet file.
      2. Load one row of it into an in-memory DuckDB table named ``data`` and
         read that table's ``CREATE TABLE`` statement back as the schema.
      3. Prompt the local model with the schema and the question.
      4. Execute the generated SQL on the same DuckDB connection.

    Args:
        dataset_name: Dataset identifier passed to ``get_first_parquet``.
        query_input: The user's question in natural language.

    Returns:
        A dict keyed by the ``query_output`` and ``df`` Gradio components:
        the generated SQL (or an error message) and a DataFrame holding the
        query result (or a single ``error`` column describing the failure).
    """
    print(f"start text2sql for {dataset_name}")
    try:
        first_parquet = get_first_parquet(dataset_name)
    except Exception as e:
        # The click handler is wired to two outputs, so the failure must be
        # reported in the same dict shape as the success path; a bare string
        # would not match the declared outputs.
        message = f"❌ Dataset does not exist or is not supported {e}"
        return {
            query_output: message,
            df: pd.DataFrame([{"error": message}]),
        }

    first_parquet_url = first_parquet["url"]
    print(first_parquet_url)

    con = duckdb.connect()
    # One try/finally around everything that uses the connection, so it is
    # closed even when schema extraction or generation raises.
    try:
        con.execute("INSTALL 'httpfs'; LOAD httpfs;")
        # Double any single quote so the URL cannot terminate the SQL literal.
        quoted_url = first_parquet_url.replace("'", "''")
        # could get from parquet instead?
        con.execute(f"CREATE TABLE data as SELECT * FROM '{quoted_url}' LIMIT 1;")
        result = con.sql("SELECT sql FROM duckdb_tables() where table_name ='data';").df()
        ddl_create = result.iloc[0, 0]

        text = f"""Given the following SQL table, your job is to write queries given a user’s request.
{ddl_create}
Write a SQL query that computes the following request: {query_input}.
"""
        print(text)

        # sql_output = query_remote_model(text)
        sql_output = query_local_model_transformers(text)

        try:
            query_result = con.sql(sql_output).df()
        except Exception as error:
            # Model output is not guaranteed to be valid SQL; show the failure
            # in the results table instead of crashing the handler.
            query_result = pd.DataFrame([{"error": f"❌ Could not execute SQL query {error=}"}])
    finally:
        con.close()

    return {
        query_output: sql_output,
        df: query_result,
    }
# Gradio UI: two inputs (dataset name, question), one button, two outputs.
# `query_output` and `df` are module-level names because text2sql() uses them
# as the keys of the dict it returns.
with gr.Blocks() as demo:
    # Page header and description.
    gr.Markdown("# Talk to your dataset")
    gr.Markdown(
        "This space shows how to talk to your datasets: Get a brief description, create SQL queries, and get results."
    )
    gr.Markdown("Generate SQL queries'")

    # Inputs, pre-filled with an example dataset and question.
    dataset_name = gr.Textbox(
        value="sksayril/medicine-info",
        label="Dataset Name",
    )
    query_input = gr.Textbox(
        value="How many rows there are?",
        label="Ask something about your data",
    )
    btn = gr.Button(value="Generate SQL")

    # Outputs: the generated SQL (read-only) and the query result table.
    query_output = gr.Textbox(label="Output SQL", interactive=False)
    df = gr.DataFrame(datatype="markdown")

    btn.click(
        fn=text2sql,
        inputs=[dataset_name, query_input],
        outputs=[query_output, df],
    )

demo.launch(debug=True)
|