Spaces:
Runtime error
Runtime error
Mustehson
committed on
Commit
·
2bcd76f
1
Parent(s):
f603f74
LanceDB
Browse files
- .gitignore +1 -0
- app.py +51 -4
- requirements.txt +3 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
app2.py
|
app.py
CHANGED
|
@@ -2,12 +2,15 @@ import os
|
|
| 2 |
import torch
|
| 3 |
import duckdb
|
| 4 |
import spaces
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from langchain_huggingface.llms import HuggingFacePipeline
|
| 8 |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
|
| 9 |
-
from langsmith import traceable
|
| 10 |
-
from langchain import hub
|
| 11 |
|
| 12 |
# Height of the Tabs Text Area
|
| 13 |
TAB_LINES = 8
|
|
@@ -16,8 +19,8 @@ TAB_LINES = 8
|
|
| 16 |
#----------CONNECT TO DATABASE----------
|
| 17 |
md_token = os.getenv('MD_TOKEN')
|
| 18 |
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
|
| 19 |
-
|
| 20 |
#---------------------------------------
|
|
|
|
| 21 |
if torch.cuda.is_available():
|
| 22 |
device = torch.device("cuda")
|
| 23 |
print(f"Using GPU: {torch.cuda.get_device_name(device)}")
|
|
@@ -26,6 +29,25 @@ else:
|
|
| 26 |
print("Using CPU")
|
| 27 |
#---------------------------------------
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
#-------LOAD HUGGINGFACE PIPELINE-------
|
| 30 |
tokenizer = AutoTokenizer.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1")
|
| 31 |
|
|
@@ -46,7 +68,9 @@ hf = HuggingFacePipeline(pipeline=pipe)
|
|
| 46 |
prompt = hub.pull("sql-agent-prompt")
|
| 47 |
#---------------------------------------
|
| 48 |
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
|
| 51 |
#--------------ALL UTILS----------------
|
| 52 |
# Get Databases
|
|
@@ -91,6 +115,20 @@ def get_prompt(schema, query_input):
|
|
| 91 |
def generate_sql(prompt):
|
| 92 |
result = hf.invoke(prompt)
|
| 93 |
return result.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
#---------------------------------------
|
| 95 |
|
| 96 |
|
|
@@ -108,6 +146,7 @@ def text2sql(table, query_input):
|
|
| 108 |
print(f'Schema Generated...')
|
| 109 |
prompt = get_prompt(schema, query_input)
|
| 110 |
print(f'Prompt Generated...')
|
|
|
|
| 111 |
try:
|
| 112 |
print(f'Generating SQL... {model.device}')
|
| 113 |
result = generate_sql(prompt)
|
|
@@ -119,6 +158,14 @@ def text2sql(table, query_input):
|
|
| 119 |
generated_query: "",
|
| 120 |
result_output:pd.DataFrame([{"error": f"β Unable to get the SQL query based on the text. {e}"}])
|
| 121 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
try:
|
| 123 |
query_result = conn.sql(result).df()
|
| 124 |
|
|
|
|
| 2 |
import torch
|
| 3 |
import duckdb
|
| 4 |
import spaces
|
| 5 |
+
import lancedb
|
| 6 |
import gradio as gr
|
| 7 |
import pandas as pd
|
| 8 |
+
import pyarrow as pa
|
| 9 |
+
from langchain import hub
|
| 10 |
+
from langsmith import traceable
|
| 11 |
+
from sentence_transformers import SentenceTransformer
|
| 12 |
from langchain_huggingface.llms import HuggingFacePipeline
|
| 13 |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Height of the Tabs Text Area
|
| 16 |
TAB_LINES = 8
|
|
|
|
| 19 |
#----------CONNECT TO DATABASE----------
|
| 20 |
md_token = os.getenv('MD_TOKEN')
|
| 21 |
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
|
|
|
|
| 22 |
#---------------------------------------
|
| 23 |
+
|
| 24 |
if torch.cuda.is_available():
|
| 25 |
device = torch.device("cuda")
|
| 26 |
print(f"Using GPU: {torch.cuda.get_device_name(device)}")
|
|
|
|
| 29 |
print("Using CPU")
|
| 30 |
#---------------------------------------
|
| 31 |
|
| 32 |
+
#--------------LanceDB-------------
|
| 33 |
+
|
| 34 |
+
lance_db = lancedb.connect(
|
| 35 |
+
uri=os.getenv('lancedb_uri'),
|
| 36 |
+
api_key=os.getenv('lancedb_api_key'),
|
| 37 |
+
region=os.getenv('lancedb_region')
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
lance_schema = pa.schema([
|
| 41 |
+
pa.field("vector", pa.list_(pa.float32())),
|
| 42 |
+
pa.field("sql-query", pa.utf8())
|
| 43 |
+
])
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
table = lance_db.create_table(name="SQL-Queries", schema=lance_schema)
|
| 47 |
+
except:
|
| 48 |
+
table = lance_db.open_table(name="SQL-Queries")
|
| 49 |
+
#---------------------------------------
|
| 50 |
+
|
| 51 |
#-------LOAD HUGGINGFACE PIPELINE-------
|
| 52 |
tokenizer = AutoTokenizer.from_pretrained("motherduckdb/DuckDB-NSQL-7B-v0.1")
|
| 53 |
|
|
|
|
| 68 |
prompt = hub.pull("sql-agent-prompt")
|
| 69 |
#---------------------------------------
|
| 70 |
|
| 71 |
+
#-----LOAD EMBEDDING MODEL-----
|
| 72 |
+
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
|
| 73 |
+
#---------------------------------------
|
| 74 |
|
| 75 |
#--------------ALL UTILS----------------
|
| 76 |
# Get Databases
|
|
|
|
| 115 |
def generate_sql(prompt):
|
| 116 |
result = hf.invoke(prompt)
|
| 117 |
return result.strip()
|
| 118 |
+
@spaces.GPU(duration=10)
|
| 119 |
+
def embed_query(sql_query):
|
| 120 |
+
print(f'Creating Emebeddings {sql_query}')
|
| 121 |
+
if sql_query is not None:
|
| 122 |
+
embeddings = embedding_model.encode(sql_query, normalize_embeddings=True).tolist()
|
| 123 |
+
return embeddings
|
| 124 |
+
|
| 125 |
+
def log2lancedb(embeddings, sql_query):
|
| 126 |
+
data = [{
|
| 127 |
+
"sql-query": sql_query,
|
| 128 |
+
"vector": embeddings
|
| 129 |
+
}]
|
| 130 |
+
table.add(data)
|
| 131 |
+
print(f'Added to Lance DB.')
|
| 132 |
#---------------------------------------
|
| 133 |
|
| 134 |
|
|
|
|
| 146 |
print(f'Schema Generated...')
|
| 147 |
prompt = get_prompt(schema, query_input)
|
| 148 |
print(f'Prompt Generated...')
|
| 149 |
+
|
| 150 |
try:
|
| 151 |
print(f'Generating SQL... {model.device}')
|
| 152 |
result = generate_sql(prompt)
|
|
|
|
| 158 |
generated_query: "",
|
| 159 |
result_output:pd.DataFrame([{"error": f"β Unable to get the SQL query based on the text. {e}"}])
|
| 160 |
}
|
| 161 |
+
|
| 162 |
+
try:
|
| 163 |
+
embeddings = embed_query(result)
|
| 164 |
+
log2lancedb(embeddings, result)
|
| 165 |
+
except Exception as e:
|
| 166 |
+
print("Error Generating and Logging Embeddings...")
|
| 167 |
+
print(e)
|
| 168 |
+
|
| 169 |
try:
|
| 170 |
query_result = conn.sql(result).df()
|
| 171 |
|
requirements.txt
CHANGED
|
@@ -4,4 +4,7 @@ transformers==4.44.2
|
|
| 4 |
duckdb==1.1.1
|
| 5 |
langsmith==0.1.135
|
| 6 |
langchain==0.3.4
|
|
|
|
|
|
|
|
|
|
| 7 |
langchain-huggingface
|
|
|
|
| 4 |
duckdb==1.1.1
|
| 5 |
langsmith==0.1.135
|
| 6 |
langchain==0.3.4
|
| 7 |
+
lancedb==0.15.0
|
| 8 |
+
sentence-transformers==3.2.1
|
| 9 |
+
pyarrow==17.0.0
|
| 10 |
langchain-huggingface
|