Mustehson committed · commit f603f74 · 1 parent: 499f079

Refactoring

Files changed:
- app.py (+18 -21)
- requirements.txt (+1 -0)
app.py CHANGED
@@ -7,26 +7,26 @@ import pandas as pd
 from langchain_huggingface.llms import HuggingFacePipeline
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
 from langsmith import traceable
-
+from langchain import hub
 
 # Height of the Tabs Text Area
 TAB_LINES = 8
-# Load Token
-md_token = os.getenv('MD_TOKEN')
 
-
-
+
+#----------CONNECT TO DATABASE----------
+md_token = os.getenv('MD_TOKEN')
 conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
 
+#---------------------------------------
 if torch.cuda.is_available():
     device = torch.device("cuda")
     print(f"Using GPU: {torch.cuda.get_device_name(device)}")
 else:
     device = torch.device("cpu")
     print("Using CPU")
+#---------------------------------------
 
-
-
+#-------LOAD HUGGINGFACE PIPELINE-------
 tokenizer = AutoTokenizer.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1")
 
 quantization_config = BitsAndBytesConfig(
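The hunk above cuts off inside the quantization block; the diff confirms only the BitsAndBytesConfig( opener here and the from_pretrained(...) call visible in the next hunk header. Below is a sketch of a typical 4-bit setup consistent with those pieces; every argument shown is an assumption, not the Space's actual configuration.

    # Hypothetical reconstruction: only BitsAndBytesConfig( and the
    # from_pretrained() call appear in the diff; the argument values below
    # are common defaults, not the committed ones.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # store weights in 4-bit
        bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
        bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for matmuls
    )
    model = AutoModelForCausalLM.from_pretrained(
        "motherduckdb/DuckDB-NSQL-7B-v0.1",
        quantization_config=quantization_config,
    )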
@@ -40,9 +40,15 @@ model = AutoModelForCausalLM.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1",
 
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, return_full_text=False)
 hf = HuggingFacePipeline(pipeline=pipe)
-
-print(f'Model Device: {model.device}')
+#---------------------------------------
 
+#-----LOAD PROMPT FROM LANGCHAIN HUB-----
+prompt = hub.pull("sql-agent-prompt")
+#---------------------------------------
+
+
+
+#--------------ALL UTILS----------------
 # Get Databases
 def get_schemas():
     schemas = conn.execute("""
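This hunk carries the substance of the refactor: the prompt template now comes from LangChain Hub instead of the hard-coded f-string removed in the next hunk. A minimal sketch of the flow, assuming the sql-agent-prompt template declares the variables schema and query_input, as the format() call that replaces the old template implies:

    # Minimal sketch: hub.pull() fetches a prompt template from LangChain
    # Hub; format() fills its variables and returns a plain string that can
    # be passed to the HuggingFacePipeline. The values here are invented.
    from langchain import hub

    prompt = hub.pull("sql-agent-prompt")
    text = prompt.format(
        schema="CREATE TABLE t (id INTEGER);",   # illustrative only
        query_input="How many rows are in t?",   # illustrative only
    )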
@@ -78,24 +84,15 @@ def get_table_schema(table):
 
 # Get Prompt
 def get_prompt(schema, query_input):
-    text = f"""
-    ### Instruction:
-    Your task is to generate valid duckdb SQL query to answer the following question.
-    ### Input:
-    Here is the database schema that the SQL query will run on:
-    {schema}
-
-    ### Question:
-    {query_input}
-    ### Response (use duckdb shorthand if possible):
-    """
-    return text
+    return prompt.format(schema=schema, query_input=query_input)
 
 @spaces.GPU(duration=60)
 @traceable()
 def generate_sql(prompt):
     result = hf.invoke(prompt)
     return result.strip()
+#---------------------------------------
+
 
 # Generate SQL
 def text2sql(table, query_input):
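With the template on the Hub, get_prompt collapses to a single format() call and generate_sql is untouched. Hypothetical end-to-end usage of the two helpers; the table name and question below are made up for illustration:

    # Illustrative call sequence; "taxi" and the question are invented.
    schema = get_table_schema("taxi")            # helper defined earlier in app.py
    question = "What is the average tip?"
    sql = generate_sql(get_prompt(schema, question))  # hf.invoke() under the hood
    print(sql)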
requirements.txt CHANGED
@@ -3,4 +3,5 @@ bitsandbytes==0.44.1
 transformers==4.44.2
 duckdb==1.1.1
 langsmith==0.1.135
+langchain==0.3.4
 langchain-huggingface
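The new langchain pin supplies the langchain.hub client that app.py now imports. A quick sanity check, assuming network access (and a LangSmith API key if the prompt is private); the expected variable names are inferred from the format() call in app.py, not confirmed by the diff:

    # Confirm the pinned langchain version exposes the Hub client app.py uses.
    import langchain
    from langchain import hub

    assert langchain.__version__ == "0.3.4"
    template = hub.pull("sql-agent-prompt")
    print(template.input_variables)  # expected: ['schema', 'query_input']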