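# NOTE: the lines below appear to be file-viewer page residue rather than
# program text; they are kept verbatim as comments so the module can parse.
# augray
# Working; needs refinements
# 57f2e80
# raw
# history blame
# 8.96 kB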
import json
import logging
import os
import urllib.parse
from typing import Any
import gradio as gr
import requests
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub.repocard import CardData, RepoCard
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Example value obtained from a freshly constructed search component.
# NOTE(review): `example` is not referenced anywhere else in the visible
# file — confirm it is still needed.
example = HuggingfaceHubSearch().example_value()
# System prompt sent to the LLM. The ``{table_name}`` and ``{features}``
# placeholders are filled in with ``str.format``. The text is assembled from
# individual sentences (joined by single spaces) and paragraphs (joined by
# blank lines).
SYSTEM_PROMPT_TEMPLATE = "\n\n".join(
    (
        " ".join(
            (
                "You are a SQL query expert assistant that returns a DuckDB SQL"
                " queries based on the user's natural language query and dataset"
                " features.",
                "You might need to use DuckDB functions for lists and aggregations,"
                " given the features.",
                "Only return the SQL query, no other text.",
                "The user may ask you to make various adjustments to the query.",
                "Every time your response should only include the refined SQL"
                " query and nothing else.",
            )
        ),
        "The table being queried is named: {table_name}.",
        "# Features\n{features}",
    )
)
def get_iframe(hub_repo_id, sql_query=None):
    """Build the HTML ``<iframe>`` embedding the Hub dataset viewer.

    Args:
        hub_repo_id: Dataset repository id (``"owner/name"``).
        sql_query: Optional SQL text. When non-empty, the viewer URL gets the
            ``sql_console=true`` flag and the percent-encoded query.

    Returns:
        An HTML snippet containing a single ``<iframe>`` element.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or ``None``.
    """
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")

    # The repo id comes from a free-text search box and ends up inside an HTML
    # attribute, so percent-encode everything except the owner/name separator.
    # Ordinary repo ids are left unchanged by this.
    repo_path = urllib.parse.quote(hub_repo_id, safe="/")
    url = f"https://huggingface.co/datasets/{repo_path}/embed/viewer"
    if sql_query:
        url += f"?sql_console=true&sql={urllib.parse.quote(sql_query)}"

    return f"""
<iframe
  src="{url}"
  frameborder="0"
  width="100%"
  height="800px"
></iframe>
"""
def get_table_info(hub_repo_id):
    """Fetch the ``dataset_info`` section for a dataset from the datasets server.

    Args:
        hub_repo_id: Dataset repository id to look up.

    Returns:
        The ``dataset_info`` value of the JSON response, serialized back to a
        JSON string (``"null"`` when the response has no such key).

    Raises:
        gr.Error: If the request fails or the response is not a JSON object.
    """
    try:
        # Let requests encode the query string, and bound the wait so a stalled
        # server cannot hang the event handler indefinitely.
        response = requests.get(
            "https://datasets-server.huggingface.co/info",
            params={"dataset": hub_repo_id},
            timeout=30,
        )
        data = response.json().get("dataset_info")
    except (requests.RequestException, ValueError, AttributeError) as e:
        # gr.Error is an exception: it only reaches the UI when raised.
        raise gr.Error(f"Error getting column info: {e}") from e
    return json.dumps(data)
def get_table_name(
config: str | None,
split: str | None,
config_choices: list[str],
split_choices: list[str],
):
if len(config_choices) > 0 and config is None:
config = config_choices[0]
if len(split_choices) > 0 and split is None:
split = split_choices[0]
if len(config_choices) > 1 and len(split_choices) > 1:
base_name = f"{config}_{split}"
elif len(config_choices) >= 1 and len(split_choices) <= 1:
base_name = config
else:
base_name = split
def replace_char(c):
if c.isalnum():
return c
if c in ["-", "_", "/"]:
return "_"
return ""
table_name = "".join(replace_char(c) for c in base_name)
if table_name[0].isdigit():
table_name = f"_{table_name}"
return table_name.lower()
def get_system_prompt(
    card_data: dict[str, Any],
    config: str | None,
    split: str | None,
):
    """Render the LLM system prompt for the selected config and split.

    Args:
        card_data: Mapping of config name to that config's info; each entry
            is expected to carry a ``"features"`` key.
        config: Selected config name, or ``None`` to use the first config.
        split: Selected split name, or ``None``.

    Returns:
        ``SYSTEM_PROMPT_TEMPLATE`` formatted with the table name and the
        features of the chosen config.

    Raises:
        KeyError: If the config is not in ``card_data`` or has no features.
    """
    config_choices = get_config_choices(card_data)
    split_choices = get_split_choices(card_data)
    table_name = get_table_name(config, split, config_choices, split_choices)
    # get_table_name falls back to the first config when none is selected;
    # apply the same fallback here so the features lookup uses the same
    # config instead of failing on card_data[None].
    if config is None and config_choices:
        config = config_choices[0]
    features = card_data[config]["features"]
    return SYSTEM_PROMPT_TEMPLATE.format(
        table_name=table_name,
        features=features,
    )
def get_config_choices(card_data: dict[str, Any]) -> list[str]:
    """Return the config names, i.e. the top-level keys of ``card_data``."""
    return [config_name for config_name in card_data.keys()]
def get_split_choices(card_data: dict[str, Any]) -> list[str]:
    """Return the distinct split names found across all configs.

    Names are returned in first-seen order, so the first element (used as the
    default split) is the same on every run rather than depending on set
    iteration order.

    Args:
        card_data: Mapping of config name to that config's info; a config may
            carry a ``"splits"`` mapping keyed by split name.

    Returns:
        The split names, without duplicates.
    """
    # A dict is used as an ordered set: re-adding a key keeps its position.
    seen: dict[str, None] = {}
    for config in card_data.values():
        seen.update(dict.fromkeys(config.get("splits") or {}))
    return list(seen)
def query_dataset(hub_repo_id, card_data, query, config, split, history):
    """Ask the LLM for a DuckDB query and embed it in the dataset viewer.

    Args:
        hub_repo_id: Dataset repository id.
        card_data: JSON string with the dataset info, or empty/``None``.
        query: The user's natural-language request.
        config: Selected config name, or ``None``.
        split: Selected split name, or ``None``.
        history: Previous ``(user, assistant)`` turns, or ``None``.

    Returns:
        A tuple ``(sql, iframe_html, history)``. When no card data is
        available the SQL is empty, the iframe shows the plain viewer and the
        history is reset to an empty list.

    Raises:
        gr.Error: If the API key is missing, the request fails, or the
            response does not have the expected shape.
    """
    if not card_data:
        return "", get_iframe(hub_repo_id), []
    card_data = json.loads(card_data)
    if not card_data:
        # The stored JSON can be "null" or "{}" when the info lookup returned
        # nothing; there are no features to build a prompt from.
        return "", get_iframe(hub_repo_id), []

    system_prompt = get_system_prompt(card_data, config, split)
    messages = [{"role": "system", "content": system_prompt}]
    for user, assistant in history or []:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": query})

    api_key = os.environ.get("API_KEY_TOGETHER_AI", "").strip()
    if not api_key:
        raise gr.Error("API_KEY_TOGETHER_AI is not configured")

    try:
        response = requests.post(
            "https://api.together.xyz/v1/chat/completions",
            json=dict(
                model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
                messages=messages,
                max_tokens=1000,
            ),
            headers={"Authorization": f"Bearer {api_key}"},
            # Bound the wait so a stalled API call cannot hang the handler.
            timeout=60,
        )
    except requests.RequestException as e:
        raise gr.Error(f"Could not query LLM for suggestion: {e}") from e

    if response.status_code != 200:
        logger.warning(response.text)
    try:
        response.raise_for_status()
        response_dict = response.json()
        duck_query = response_dict["choices"][0]["message"]["content"]
    except (
        requests.RequestException,
        ValueError,
        KeyError,
        IndexError,
        TypeError,
    ) as e:
        # gr.Error must be raised to be shown; merely constructing it would
        # let a failed response fall through to the parsing above.
        raise gr.Error(f"Could not query LLM for suggestion: {e}") from e

    duck_query = _sanitize_duck_query(duck_query)
    # Build a new list instead of mutating the caller's history in place.
    updated_history = [*(history or []), (query, duck_query)]
    return duck_query, get_iframe(hub_repo_id, duck_query), updated_history
def _sanitize_duck_query(duck_query: str) -> str:
# Sometimes the LLM wraps the query like this:
# ```sql
# select * from x;
# ```
# This removes that wrapping if present.
if "```" not in duck_query:
return duck_query
start_idx = duck_query.index("```") + len("```")
end_idx = duck_query.rindex("```")
duck_query = duck_query[start_idx:end_idx]
if duck_query.startswith("sql\n"):
duck_query = duck_query.replace("sql\n", "", 1)
return duck_query
# UI definition.
# NOTE(review): indentation was reconstructed. Everything from the config/split
# dropdowns through the event listeners is placed inside the render function
# because those listeners reference components created there — confirm against
# the original layout.
with gr.Blocks() as demo:
    gr.Markdown("""# 🐥 🦙 🤗 Text To SQL Hub Datasets 🤗 🦙 🐥
This is a basic text to SQL tool that allows you to query datasets on Huggingface Hub.
It is built with [DuckDB](https://duckdb.org/), [Huggingface's Inference API](https://huggingface.co/docs/api-inference/index), and [LLama 3.1 70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct).
Also, it uses the [dataset-server API](https://redocly.github.io/redoc/?url=https://datasets-server.huggingface.co/openapi.json#operation/isValidDataset).
""")
    with gr.Row():
        search_in = HuggingfaceHubSearch(
            label="Search Huggingface Hub",
            # The search is restricted to datasets (search_type below).
            placeholder="Search for datasets on Huggingface",
            search_type="dataset",
            sumbit_on_select=True,
        )
    with gr.Row():
        show_btn = gr.Button("Show Dataset")
    with gr.Row():
        sql_out = gr.Code(
            label="DuckDB SQL Query",
            interactive=True,
            language="sql",
            lines=1,
            visible=False,
        )
    with gr.Row():
        # Hidden holder for the dataset info JSON written by get_table_info.
        card_data = gr.Code(label="Card data", language="json", visible=False)

    @gr.render(inputs=[card_data])
    def show_config_split_choices(data):
        """Build the config/split selectors, query widgets and listeners."""
        try:
            data = json.loads(data.strip())
            config_choices = get_config_choices(data)
            split_choices = get_split_choices(data)
        except Exception:
            # Missing or unparsable card data: render empty selectors.
            config_choices = []
            split_choices = []
        initial_config = config_choices[0] if len(config_choices) > 0 else None
        initial_split = split_choices[0] if len(split_choices) > 0 else None
        with gr.Row():
            with gr.Column():
                config_selection = gr.Dropdown(
                    label="Config Name", choices=config_choices, value=initial_config
                )
            with gr.Column():
                split_selection = gr.Dropdown(
                    label="Split Name", choices=split_choices, value=initial_split
                )
        with gr.Accordion("Query Suggestion History.", open=False) as accordion:
            chatbot = gr.Chatbot(height=200, layout="bubble")
        with gr.Row():
            query = gr.Textbox(
                label="Query Description",
                placeholder="Enter a natural language query to generate SQL",
            )
        with gr.Row():
            with gr.Column():
                query_btn = gr.Button("Get Suggested Query")
            with gr.Column():
                clear = gr.ClearButton([query, chatbot], value="Reset Query History")
        with gr.Row():
            search_out = gr.HTML(label="Search Results")
        # Show the plain viewer first, then load the dataset info.
        gr.on(
            [show_btn.click, search_in.submit],
            fn=get_iframe,
            inputs=[search_in],
            outputs=[search_out],
        ).then(
            fn=get_table_info,
            inputs=[search_in],
            outputs=[card_data],
        )
        # Generate SQL from the description and refresh viewer and history.
        gr.on(
            [query_btn.click, query.submit],
            fn=query_dataset,
            inputs=[
                search_in,
                card_data,
                query,
                config_selection,
                split_selection,
                chatbot,
            ],
            outputs=[sql_out, search_out, chatbot],
        )
        # Open the history accordion when a suggestion is requested.
        gr.on([query_btn.click], fn=lambda: gr.update(open=True), outputs=[accordion])
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()