import os import torch import duckdb import spaces import gradio as gr import pandas as pd from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # Height of the Tabs Text Area TAB_LINES = 8 # Load Token md_token = os.getenv('MD_TOKEN') print('Connecting to DB...') # Connect to DB conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True) if torch.cuda.is_available(): device = torch.device("cuda") print(f"Using GPU: {torch.cuda.get_device_name(device)}") else: device = torch.device("cpu") print("Using CPU") print('Loading Model...') tokenizer = AutoTokenizer.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1") quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type= "nf4") model = AutoModelForCausalLM.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1", quantization_config=quantization_config, device_map="auto", torch_dtype=torch.bfloat16) print('Model Loaded...') print(f'Model Device: {model.device}') # Get Databases def get_schemas(): schemas = conn.execute(""" SELECT DISTINCT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('information_schema', 'pg_catalog') """).fetchall() return [item[0] for item in schemas] # Get Tables def get_tables(schema_name): tables = conn.execute(f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}'").fetchall() return [table[0] for table in tables] # Update Tables def update_tables(schema_name): tables = get_tables(schema_name) return gr.update(choices=tables) # Get Schema def get_table_schema(table): result = conn.sql(f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';").df() ddl_create = result.iloc[0,0] parent_database = result.iloc[0,1] schema_name = result.iloc[0,2] full_path = f"{parent_database}.{schema_name}.{table}" if schema_name != "main": old_path = f"{schema_name}.{table}" else: old_path = table ddl_create = ddl_create.replace(old_path, full_path) return ddl_create # Get Prompt def get_prompt(schema, query_input): text = f""" ### Instruction: Your task is to generate valid duckdb SQL query to answer the following question. ### Input: Here is the database schema that the SQL query will run on: {schema} ### Question: {query_input} ### Response (use duckdb shorthand if possible): """ return text @spaces.GPU(duration=60) def generate_sql(prompt): input_ids = tokenizer(prompt, return_tensors="pt").input_ids input_token_len = input_ids.shape[1] outputs = model.generate(input_ids.to(model.device), max_new_tokens=1024) result = tokenizer.decode(outputs[0][input_token_len:], skip_special_tokens=True) return result # Generate SQL def text2sql(table, query_input): if table is None: return { table_schema: "", input_prompt: "", generated_query: "", result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}]) } schema = get_table_schema(table) print(f'Schema Generated...') prompt = get_prompt(schema, query_input) print(f'Prompt Generated...') try: print(f'Generating SQL... {model.device}') result = generate_sql(prompt) print('SQL Generated...') except Exception as e: return { table_schema: schema, input_prompt: prompt, generated_query: "", result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}]) } try: query_result = conn.sql(result).df() except Exception as e: return { table_schema: schema, input_prompt: prompt, generated_query: result, result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}]) } return { table_schema: schema, input_prompt: prompt, generated_query: result, result_output:query_result } # Custom CSS styling custom_css = """ .gradio-container { background-color: #f0f4f8; } .logo { max-width: 200px; margin: 20px auto; display: block; } .gr-button { background-color: #4a90e2 !important; } .gr-button:hover { background-color: #3a7bc8 !important; } """ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo: gr.Image("logo.png", label=None, show_label=False, container=False, height=100) gr.Markdown("""