Mustehson committed on
Commit
f603f74
·
1 Parent(s): 499f079

Refactoring

Browse files
Files changed (2) hide show
  1. app.py +18 -21
  2. requirements.txt +1 -0
app.py CHANGED
@@ -7,26 +7,26 @@ import pandas as pd
7
  from langchain_huggingface.llms import HuggingFacePipeline
8
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
9
  from langsmith import traceable
10
-
11
 
12
  # Height of the Tabs Text Area
13
  TAB_LINES = 8
14
- # Load Token
15
- md_token = os.getenv('MD_TOKEN')
16
 
17
- print('Connecting to DB...')
18
- # Connect to DB
 
19
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
20
 
 
21
  if torch.cuda.is_available():
22
  device = torch.device("cuda")
23
  print(f"Using GPU: {torch.cuda.get_device_name(device)}")
24
  else:
25
  device = torch.device("cpu")
26
  print("Using CPU")
 
27
 
28
- print('Loading Model...')
29
-
30
  tokenizer = AutoTokenizer.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1")
31
 
32
  quantization_config = BitsAndBytesConfig(
@@ -40,9 +40,15 @@ model = AutoModelForCausalLM.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1",
40
 
41
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, return_full_text=False)
42
  hf = HuggingFacePipeline(pipeline=pipe)
43
- print('Model Loaded...')
44
- print(f'Model Device: {model.device}')
45
 
 
 
 
 
 
 
 
46
  # Get Databases
47
  def get_schemas():
48
  schemas = conn.execute("""
@@ -78,24 +84,15 @@ def get_table_schema(table):
78
 
79
  # Get Prompt
80
  def get_prompt(schema, query_input):
81
- text = f"""
82
- ### Instruction:
83
- Your task is to generate valid duckdb SQL query to answer the following question.
84
- ### Input:
85
- Here is the database schema that the SQL query will run on:
86
- {schema}
87
-
88
- ### Question:
89
- {query_input}
90
- ### Response (use duckdb shorthand if possible):
91
- """
92
- return text
93
 
94
  @spaces.GPU(duration=60)
95
  @traceable()
96
  def generate_sql(prompt):
97
  result = hf.invoke(prompt)
98
  return result.strip()
 
 
99
 
100
  # Generate SQL
101
  def text2sql(table, query_input):
 
7
  from langchain_huggingface.llms import HuggingFacePipeline
8
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
9
  from langsmith import traceable
10
+ from langchain import hub
11
 
12
  # Height of the Tabs Text Area
13
  TAB_LINES = 8
 
 
14
 
15
+
16
+ #----------CONNECT TO DATABASE----------
17
+ md_token = os.getenv('MD_TOKEN')
18
  conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
19
 
20
+ #---------------------------------------
21
  if torch.cuda.is_available():
22
  device = torch.device("cuda")
23
  print(f"Using GPU: {torch.cuda.get_device_name(device)}")
24
  else:
25
  device = torch.device("cpu")
26
  print("Using CPU")
27
+ #---------------------------------------
28
 
29
+ #-------LOAD HUGGINGFACE PIPELINE-------
 
30
  tokenizer = AutoTokenizer.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1")
31
 
32
  quantization_config = BitsAndBytesConfig(
 
40
 
41
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, return_full_text=False)
42
  hf = HuggingFacePipeline(pipeline=pipe)
43
+ #---------------------------------------
 
44
 
45
+ #-----LOAD PROMPT FROM LANGCHAIN HUB-----
46
+ prompt = hub.pull("sql-agent-prompt")
47
+ #---------------------------------------
48
+
49
+
50
+
51
+ #--------------ALL UTILS----------------
52
  # Get Databases
53
  def get_schemas():
54
  schemas = conn.execute("""
 
84
 
85
  # Get Prompt
86
  def get_prompt(schema, query_input):
87
+ return prompt.format(schema=schema, query_input=query_input)
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  @spaces.GPU(duration=60)
90
  @traceable()
91
  def generate_sql(prompt):
92
  result = hf.invoke(prompt)
93
  return result.strip()
94
+ #---------------------------------------
95
+
96
 
97
  # Generate SQL
98
  def text2sql(table, query_input):
requirements.txt CHANGED
@@ -3,4 +3,5 @@ bitsandbytes==0.44.1
3
  transformers==4.44.2
4
  duckdb==1.1.1
5
  langsmith==0.1.135
 
6
  langchain-huggingface
 
3
  transformers==4.44.2
4
  duckdb==1.1.1
5
  langsmith==0.1.135
6
+ langchain==0.3.4
7
  langchain-huggingface