Commit f3486de
Parent(s): 189817d

Tests exec on db before output

- app.py: +88 -33
- requirements.txt: +2 -1
app.py
CHANGED
@@ -1,18 +1,18 @@
 import os
 import gradio as gr
+import sqlite3
 import sqlparse
 import requests
 from time import sleep
 import re
 import platform
+import openai
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     StoppingCriteria,
     StoppingCriteriaList,
-    TextIteratorStreamer
 )
-from threading import Event, Thread
 # Additional Firebase imports
 import firebase_admin
 from firebase_admin import credentials, firestore
@@ -20,7 +20,6 @@ import json
 import base64
 import torch
 
-
 print(f"Running on {platform.system()}")
 
 if platform.system() == "Windows" or platform.system() == "Darwin":
@@ -33,7 +32,25 @@ initial_model = "WizardLM/WizardCoder-15B-V1.0"
 lora_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
 dataset = "richardr1126/spider-skeleton-context-instruct"
 
-
+model_name = os.getenv("HF_MODEL_NAME", None)
+tok = AutoTokenizer.from_pretrained(model_name)
+
+max_new_tokens = 1024
+
+print(f"Starting to load the model {model_name}")
+
+m = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map=0,
+    #load_in_8bit=True,
+)
+
+m.config.pad_token_id = m.config.eos_token_id
+m.generation_config.pad_token_id = m.config.eos_token_id
+
+print(f"Successfully loaded the model {model_name} into memory")
+
+################# Firebase code #################
 # Initialize Firebase
 base64_string = os.getenv('FIREBASE')
 base64_bytes = base64_string.encode('utf-8')
@@ -84,7 +101,7 @@ def log_rating_to_firestore(input_message, db_info, temperature, response_text,
     }
     doc_ref.set(log_data)
     gr.Info("Thanks for your feedback!")
-
+############### End Firebase code ###############
 
 def format(text):
     # Split the text by "|", and get the last element in the list which should be the final query
@@ -105,23 +122,63 @@ def format(text):
 
     return final_query_markdown
 
-model_name = os.getenv("HF_MODEL_NAME", None)
-tok = AutoTokenizer.from_pretrained(model_name)
+def extract_db_code(text):
+    pattern = r'```(?:\w+)?\s?(.*?)```'
+    matches = re.findall(pattern, text, re.DOTALL)
+    return [match.strip() for match in matches]
+
+def generate_dummy_db(db_info, question, query):
+    pre_prompt = "Generate a SQLite database with dummy data for this database, output the SQL code in a SQL code block. Make sure you add dummy data relevant to the question and query.\n\n"
+    prompt = pre_prompt + db_info + "\n\nQuestion: " + question + "\nQuery: " + query
+
+    while True:
+        try:
+            response = openai.ChatCompletion.create(
+                model="gpt-3.5-turbo",
+                messages=[
+                    {"role": "user", "content": prompt}
+                ],
+                #temperature=0.7,
+            )
+            response_text = response['choices'][0]['message']['content']
+
+            db_code = extract_db_code(response_text)
+
+            return db_code
+
+        except Exception as e:
+            print(f'Error occurred: {str(e)}')
+            print('Waiting for 20 seconds before retrying...')
+            sleep(20)  # sleep comes from "from time import sleep" at the top of the file
+
+def test_query_on_dummy_db(db_code, query):
+    try:
+        # Connect to an SQLite database in memory
+        conn = sqlite3.connect(':memory:')
+        cursor = conn.cursor()
 
-max_new_tokens = 1024
+        # Iterate over each extracted SQL block and split them into individual commands
+        for sql_block in db_code:
+            statements = sqlparse.split(sql_block)
+
+            # Execute each SQL command
+            for statement in statements:
+                if statement:
+                    cursor.execute(statement)
 
-print(f"Starting to load the model {model_name}")
+        # Run the provided test query against the database
+        cursor.execute(query)
+        print(cursor.fetchall())
 
-m = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    device_map=0,
-    #load_in_8bit=True,
-)
+        # Close the connection
+        conn.close()
 
-m.config.pad_token_id = m.config.eos_token_id
-m.generation_config.pad_token_id = m.config.eos_token_id
+        # If everything executed without errors, return True
+        return True
 
-print(f"Successfully loaded the model {model_name} into memory")
+    except Exception as e:
+        print(f"Error encountered: {e}")
+        return False
 
 
 def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
@@ -139,7 +196,6 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
 
     input_ids = tok(messages, return_tensors="pt").input_ids
     input_ids = input_ids.to(m.device)
-    #streamer = TextIteratorStreamer(tok, timeout=1000.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         input_ids=input_ids,
         max_new_tokens=max_new_tokens,
@@ -154,15 +210,6 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
         do_sample=do_sample,
     )
 
-    #stream_complete = Event()
-
-    # def generate_and_signal_complete():
-    #     m.generate(**generate_kwargs)
-    #     stream_complete.set()
-
-    # t1 = Thread(target=generate_and_signal_complete)
-    # t1.start()
-
     tokens = m.generate(**generate_kwargs)
 
     responses = []
@@ -172,14 +219,21 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
         # Only take what comes after ### Response:
         response_text = response_text.split("### Response:")[1].strip()
 
-
+        query = format(response_text) if format_sql else response_text
         if (num_return_sequences > 1):
-
-
-
+            query = query.replace("\n", " ").replace("\t", " ").strip()
+            # Test against dummy database
+            db_code = generate_dummy_db(db_info, input_message, query)
+            success = test_query_on_dummy_db(db_code, query)
+            if success:
+                responses.append(query)
+            else:
+                responses.append(query)
+
 
     # Concat responses to be a single string separated by a newline
-    output = "\n".join(responses)
+    #output = "\n".join(responses)
+    output = responses[0] if responses else ""
 
     if log:
         # Log the request to Firestore
@@ -219,7 +273,8 @@ with gr.Blocks(theme='gradio/soft') as demo:
         repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)
 
     with gr.Accordion("Generation strategies", open=False):
-
+        md_description = gr.Markdown("""Increasing num return sequences will increase the number of SQLs generated, but will still yield only the best output of the number of return sequences. SQLs are tested against the db info you provide.""")
+        num_return_sequences = gr.Slider(label="Number of return sequences (to generate and test)", minimum=1, maximum=5, value=1, step=1)
         num_beams = gr.Slider(label="Num Beams", minimum=1, maximum=5, value=1, step=1)
         do_sample = gr.Checkbox(label="Do Sample", value=False, interactive=True)
 
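Note on the new test path: test_query_on_dummy_db builds a throwaway in-memory SQLite database from the LLM-generated schema and runs the candidate query against it, treating any exception as a failed candidate. Below is a minimal offline sketch of that flow; the hand-written schema and table names here are illustrative stand-ins for what generate_dummy_db would get back from GPT-3.5.

import sqlite3
import sqlparse

# Hand-written stand-in for generate_dummy_db() output: a list of SQL blocks.
dummy_db_code = ["""
CREATE TABLE singer (singer_id INTEGER PRIMARY KEY, name TEXT, age INTEGER);
INSERT INTO singer VALUES (1, 'Alice', 30);
INSERT INTO singer VALUES (2, 'Bob', 45);
"""]
candidate_query = "SELECT name FROM singer WHERE age > 35"

# Same execution pattern as test_query_on_dummy_db: split each block into
# statements with sqlparse, execute them one by one, then run the candidate.
conn = sqlite3.connect(":memory:")  # throwaway database, nothing touches disk
cursor = conn.cursor()
try:
    for sql_block in dummy_db_code:
        for statement in sqlparse.split(sql_block):
            if statement:
                cursor.execute(statement)
    cursor.execute(candidate_query)
    print(cursor.fetchall())  # [('Bob',)]
    ok = True
except Exception as e:
    print(f"Error encountered: {e}")
    ok = False
finally:
    conn.close()

As committed, both branches of the if success check append the query, so a failed test is logged but does not yet filter a candidate out.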
requirements.txt
CHANGED
@@ -8,4 +8,5 @@ scipy
 transformers
 accelerate
 sqlparse
-firebase_admin
+firebase_admin
+openai
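The new openai entry pairs with the import openai added in app.py. The diff doesn't show where the API key is configured; a minimal sketch of the likely setup, assuming a hypothetical OPENAI_API_KEY secret alongside the existing FIREBASE one:

import os
import openai

# Hypothetical: mirrors the FIREBASE env-var pattern used elsewhere in app.py.
# openai.ChatCompletion.create() (pre-1.0 openai client) fails without a key.
openai.api_key = os.getenv("OPENAI_API_KEY")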