from gradio_huggingfacehub_search import HuggingfaceHubSearch
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import pandas as pd
import gradio as gr
import duckdb
import requests
import llama_cpp
import instructor
import spaces
import enum
import os

from pydantic import BaseModel, Field

BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
view_name = "dataset_view"
hf_api = HfApi()
conn = duckdb.connect()

# Runtime configuration, overridable via environment variables
gpu_layers = int(os.environ.get("GPU_LAYERS", 81))
draft_pred_tokens = int(os.environ.get("DRAFT_PRED_TOKENS", 2))

repo_id = os.getenv("MODEL_REPO_ID", "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF")
model_file_name = os.getenv("MODEL_FILE_NAME", "Hermes-2-Pro-Llama-3-8B-Q8_0.gguf")

# Download the GGUF model weights once at startup
hf_hub_download(
    repo_id=repo_id,
    filename=model_file_name,
    local_dir="./models",
)


class OutputTypes(str, enum.Enum):
    TABLE = "table"
    BARCHART = "barchart"
    LINECHART = "linechart"


class SQLResponse(BaseModel):
    """Structured output schema the LLM must produce."""

    sql: str
    visualization_type: Optional[OutputTypes] = Field(
        None, description="The type of visualization to display"
    )
    data_key: Optional[str] = Field(
        None,
        description="The column name from the sql query that contains the data for chart responses",
    )
    label_key: Optional[str] = Field(
        None,
        description="The column name from the sql query that contains the labels for chart responses",
    )


def get_dataset_ddl(dataset_id: str) -> str:
    """Register the dataset's first parquet file as a DuckDB view and return a
    CREATE TABLE statement describing its schema."""
    response = requests.get(f"{BASE_DATASETS_SERVER_URL}/parquet?dataset={dataset_id}")
    response.raise_for_status()  # Check if the request was successful

    first_parquet = response.json().get("parquet_files", [])[0]
    first_parquet_url = first_parquet.get("url")
    if not first_parquet_url:
        raise ValueError("No valid URL found for the first parquet file.")

    conn.execute(
        f"CREATE OR REPLACE VIEW {view_name} as SELECT * FROM read_parquet('{first_parquet_url}');"
    )

    dataset_ddl = conn.execute(f"PRAGMA table_info('{view_name}');").fetchall()
    column_data_types = ",\n\t".join(
        [f"{column[1]} {column[2]}" for column in dataset_ddl]
    )

    sql_ddl = """
    CREATE TABLE {} (
        {}
    );
    """.format(
        view_name, column_data_types
    )

    return sql_ddl


@spaces.GPU(duration=120)
def generate_query(ddl: str, query: str) -> dict:
    """Translate a natural-language question into a SQLResponse using the local
    GGUF model, with prompt-lookup decoding as the draft model."""
    llama = llama_cpp.Llama(
        model_path=f"models/{model_file_name}",
        n_gpu_layers=gpu_layers,
        chat_format="chatml",
        draft_model=LlamaPromptLookupDecoding(num_pred_tokens=draft_pred_tokens),
        logits_all=True,
        n_ctx=2048,
        verbose=True,
        temperature=0.1,
    )

    # Patch the llama-cpp chat completion so instructor can enforce the
    # SQLResponse JSON schema on the model's output
    create = instructor.patch(
        create=llama.create_chat_completion_openai_v1,
        mode=instructor.Mode.JSON_SCHEMA,
    )

    system_prompt = f"""
    You are an expert SQL assistant with access to the following PostgreSQL Table:

    ```sql
    {ddl.strip()}
    ```

    Please assist the user by writing a SQL query that answers the user's question.
    """

    print("Calling LLM with system prompt: ", system_prompt, query)
    resp: SQLResponse = create(
        model="Hermes-2-Pro-Llama-3-8B",
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": query,
            },
        ],
        response_model=SQLResponse,
    )

    print("Received Response: ", resp)
    return resp.model_dump()


def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
    """Run the generated SQL against the dataset view and return the results,
    the SQL as markdown, and an optional chart."""
    ddl = get_dataset_ddl(dataset_id)
    response = generate_query(ddl, query)

    print("Querying Parquet...")
    df = conn.execute(response.get("sql")).fetchdf()
    plot = None
    label_key = response.get("label_key")
    data_key = response.get("data_key")
    viz_type = response.get("visualization_type")
    sql = response.get("sql")
    markdown_output = f"""```sql\n{sql}\n```"""

    # Handle incorrect data and label keys returned by the model
    if label_key and label_key not in df.columns:
        label_key = None
    if data_key and data_key not in df.columns:
        data_key = None

    if df.empty:
        return df, markdown_output, plot

    if viz_type == OutputTypes.LINECHART:
        plot = df.plot(kind="line", x=label_key, y=data_key).get_figure()
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()
    elif viz_type == OutputTypes.BARCHART:
        plot = df.plot(kind="bar", x=label_key, y=data_key).get_figure()
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()

    return df, markdown_output, plot


with gr.Blocks() as demo:
    gr.Markdown("# Query your HF Datasets with Natural Language 📈📊")
    dataset_id = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Find your favorite dataset...",
        search_type="dataset",
        value="gretelai/synthetic_text_to_sql",
    )
    user_query = gr.Textbox("", label="Ask anything...")

    examples = [
        ["Show me a preview of the data"],
        ["Show me something interesting"],
        ["Which row has the longest description length?"],
        ["Find the average length of sql query context"],
    ]
    gr.Examples(examples=examples, inputs=[user_query], outputs=[])

    btn = gr.Button("Ask 🪄")

    sql_query = gr.Markdown(label="Output SQL Query")
    df = gr.DataFrame()
    plot = gr.Plot()

    btn.click(
        query_dataset,
        inputs=[dataset_id, user_query],
        outputs=[df, sql_query, plot],
    )

if __name__ == "__main__":
    demo.launch()