import os
import torch
import duckdb
import spaces
import gradio as gr
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Height of the Tabs Text Area
TAB_LINES = 8

# Load Token
md_token = os.getenv('MD_TOKEN')

print('Connecting to DB...')

# Connect to DB
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
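# All catalog lookups and generated queries below share this single read-only
# MotherDuck connection, authenticated with the MD_TOKEN secret.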
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")
print('Loading Model...')

tokenizer = AutoTokenizer.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1")
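# Load the model in 4-bit NF4 with double quantization and bfloat16 compute so
# the 7B model fits in limited GPU memory.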
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    "motherduckdb/DuckDB-NSQL-7B-v0.1",
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

print('Model Loaded...')
print(f'Model Device: {model.device}')
# Get Schemas
def get_schemas():
    schemas = conn.execute("""
        SELECT DISTINCT schema_name
        FROM information_schema.schemata
        WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """).fetchall()
    return [item[0] for item in schemas]
# Get Tables
def get_tables(schema_name):
    tables = conn.execute(f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}'").fetchall()
    return [table[0] for table in tables]

# Update Tables
def update_tables(schema_name):
    tables = get_tables(schema_name)
    return gr.update(choices=tables)
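# update_tables is bound to schema_dropdown.change below, so the table list
# always follows the selected schema.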
# Get Schema
def get_table_schema(table):
    result = conn.sql(f"SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = '{table}';").df()
    ddl_create = result.iloc[0, 0]
    parent_database = result.iloc[0, 1]
    schema_name = result.iloc[0, 2]
    full_path = f"{parent_database}.{schema_name}.{table}"
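    # Rewrite the table reference in the CREATE statement to its fully
    # qualified database.schema.table form, so the schema shown to the model
    # matches the path the generated SQL has to use.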
    if schema_name != "main":
        old_path = f"{schema_name}.{table}"
    else:
        old_path = table
    ddl_create = ddl_create.replace(old_path, full_path)
    return ddl_create
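# Note: schema and table names are interpolated directly into the SQL above;
# they come from dropdowns populated by catalog queries rather than free text.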
# Get Prompt
def get_prompt(schema, query_input):
    text = f"""
### Instruction:
Your task is to generate valid duckdb SQL query to answer the following question.
### Input:
Here is the database schema that the SQL query will run on:
{schema}
### Question:
{query_input}
### Response (use duckdb shorthand if possible):
"""
    return text
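
# Assumption: this Space targets ZeroGPU (hence the otherwise unused `spaces`
# import); on ZeroGPU the GPU-bound generation call must be wrapped with
# @spaces.GPU so a device is allocated when it runs.
@spaces.GPU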
def generate_sql(prompt):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_token_len = input_ids.shape[1]
    outputs = model.generate(input_ids.to(model.device), max_new_tokens=1024)
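    # Decode only the newly generated tokens, skipping the echoed prompt.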
    result = tokenizer.decode(outputs[0][input_token_len:], skip_special_tokens=True)
    return result
# Generate and run SQL
def text2sql(table, query_input):
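    # Each branch returns a dict keyed by output components; Gradio matches the
    # keys against the `outputs` list of the click handler defined below.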
    if table is None:
        return {
            table_schema: "",
            input_prompt: "",
            generated_query: "",
            result_output: pd.DataFrame([{"error": "Please select a schema and a table first."}])
        }
    schema = get_table_schema(table)
    print('Schema Generated...')
    prompt = get_prompt(schema, query_input)
    print('Prompt Generated...')

    try:
        print(f'Generating SQL... {model.device}')
        result = generate_sql(prompt)
        print('SQL Generated...')
    except Exception as e:
        return {
            table_schema: schema,
            input_prompt: prompt,
            generated_query: "",
            result_output: pd.DataFrame([{"error": f"Unable to generate a SQL query for this request. {e}"}])
        }
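    # The generated SQL runs on the same read-only connection, so statements
    # that try to modify data will fail here rather than change anything.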
    try:
        query_result = conn.sql(result).df()
    except Exception as e:
        return {
            table_schema: schema,
            input_prompt: prompt,
            generated_query: result,
            result_output: pd.DataFrame([{"error": f"Unable to run the generated SQL query. {e}"}])
        }

    return {
        table_schema: schema,
        input_prompt: prompt,
        generated_query: result,
        result_output: query_result
    }
# Custom CSS styling
custom_css = """
.gradio-container {
    background-color: #f0f4f8;
}
.logo {
    max-width: 200px;
    margin: 20px auto;
    display: block;
}
.gr-button {
    background-color: #4a90e2 !important;
}
.gr-button:hover {
    background-color: #3a7bc8 !important;
}
"""
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
    <div style='text-align: center;'>
        <strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
        <br>
        <span style='font-size: 20px;'>Generate and run SQL queries from a natural-language question about the selected dataset.</span>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1, variant='panel'):
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)

        with gr.Column(scale=2):
            query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter your text query here...")
            with gr.Row():
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    generate_query_button = gr.Button("Run Query", variant="primary")
    with gr.Tabs():
        with gr.Tab("Result"):
            result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
        with gr.Tab("SQL Query"):
            generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
        with gr.Tab("Prompt"):
            input_prompt = gr.Textbox(lines=TAB_LINES, label="Input Prompt", value="", interactive=False)
        with gr.Tab("Schema"):
            table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)

    schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
    generate_query_button.click(
        text2sql,
        inputs=[tables_dropdown, query_input],
        outputs=[table_schema, input_prompt, generated_query, result_output],
    )
if __name__ == "__main__":
    demo.launch()