richardr1126 committed on
Commit f3486de · 1 Parent(s): 189817d

Tests exec on db before output

Files changed (2)
  1. app.py +88 -33
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,18 +1,18 @@
 import os
 import gradio as gr
+import sqlite3
 import sqlparse
 import requests
 from time import sleep
 import re
 import platform
+import openai
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     StoppingCriteria,
     StoppingCriteriaList,
-    TextIteratorStreamer
 )
-from threading import Event, Thread
 # Additional Firebase imports
 import firebase_admin
 from firebase_admin import credentials, firestore
@@ -20,7 +20,6 @@ import json
 import base64
 import torch
 
-
 print(f"Running on {platform.system()}")
 
 if platform.system() == "Windows" or platform.system() == "Darwin":
@@ -33,7 +32,25 @@ initial_model = "WizardLM/WizardCoder-15B-V1.0"
 lora_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
 dataset = "richardr1126/spider-skeleton-context-instruct"
 
-# Firebase code
+model_name = os.getenv("HF_MODEL_NAME", None)
+tok = AutoTokenizer.from_pretrained(model_name)
+
+max_new_tokens = 1024
+
+print(f"Starting to load the model {model_name}")
+
+m = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map=0,
+    #load_in_8bit=True,
+)
+
+m.config.pad_token_id = m.config.eos_token_id
+m.generation_config.pad_token_id = m.config.eos_token_id
+
+print(f"Successfully loaded the model {model_name} into memory")
+
+################# Firebase code #################
 # Initialize Firebase
 base64_string = os.getenv('FIREBASE')
 base64_bytes = base64_string.encode('utf-8')
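For reference: the Firebase block continues past base64_bytes by decoding the base64-encoded service-account JSON in the FIREBASE env var and opening a Firestore client. That continuation is elided from this diff, so the following is a sketch of the usual firebase_admin pattern, not the committed code:

    # Sketch (assumed continuation, not the committed code):
    # decode a base64-encoded service-account key and open a Firestore client
    import base64, json, os
    import firebase_admin
    from firebase_admin import credentials, firestore

    base64_bytes = os.getenv('FIREBASE').encode('utf-8')
    cert = json.loads(base64.b64decode(base64_bytes).decode('utf-8'))
    firebase_admin.initialize_app(credentials.Certificate(cert))
    db = firestore.client()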
@@ -84,7 +101,7 @@ def log_rating_to_firestore(input_message, db_info, temperature, response_text,
     }
     doc_ref.set(log_data)
     gr.Info("Thanks for your feedback!")
-# End Firebase code
+############### End Firebase code ###############
 
 def format(text):
     # Split the text by "|", and get the last element in the list which should be the final query
@@ -105,23 +122,63 @@ def format(text):
 
     return final_query_markdown
 
-model_name = os.getenv("HF_MODEL_NAME", None)
-tok = AutoTokenizer.from_pretrained(model_name)
-
-max_new_tokens = 1024
-
-print(f"Starting to load the model {model_name}")
-
-m = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    device_map=0,
-    #load_in_8bit=True,
-)
-
-m.config.pad_token_id = m.config.eos_token_id
-m.generation_config.pad_token_id = m.config.eos_token_id
-
-print(f"Successfully loaded the model {model_name} into memory")
+def extract_db_code(text):
+    pattern = r'```(?:\w+)?\s?(.*?)```'
+    matches = re.findall(pattern, text, re.DOTALL)
+    return [match.strip() for match in matches]
+
+def generate_dummy_db(db_info, question, query):
+    pre_prompt = "Generate a SQLite database with dummy data for this database, output the SQL code in a SQL code block. Make sure you add dummy data relevant to the question and query.\n\n"
+    prompt = pre_prompt + db_info + "\n\nQuestion: " + question + "\nQuery: " + query
+
+    while True:
+        try:
+            response = openai.ChatCompletion.create(
+                model="gpt-3.5-turbo",
+                messages=[
+                    {"role": "user", "content": prompt}
+                ],
+                #temperature=0.7,
+            )
+            response_text = response['choices'][0]['message']['content']
+
+            db_code = extract_db_code(response_text)
+
+            return db_code
+
+        except Exception as e:
+            print(f'Error occurred: {str(e)}')
+            print('Waiting for 20 seconds before retrying...')
+            sleep(20)
+
+def test_query_on_dummy_db(db_code, query):
+    try:
+        # Connect to an SQLite database in memory
+        conn = sqlite3.connect(':memory:')
+        cursor = conn.cursor()
+
+        # Iterate over each extracted SQL block and split them into individual commands
+        for sql_block in db_code:
+            statements = sqlparse.split(sql_block)
+
+            # Execute each SQL command
+            for statement in statements:
+                if statement:
+                    cursor.execute(statement)
+
+        # Run the provided test query against the database
+        cursor.execute(query)
+        print(cursor.fetchall())
+
+        # Close the connection
+        conn.close()
+
+        # If everything executed without errors, return True
+        return True
+
+    except Exception as e:
+        print(f"Error encountered: {e}")
+        return False
 
 
 def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
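The two new helpers are easiest to sanity-check offline by handing test_query_on_dummy_db a handwritten schema instead of the gpt-3.5-turbo output from generate_dummy_db. A minimal sketch (not part of the commit; assumes the helpers are copied into a scratch script):

    # Sketch: validate queries against a handwritten in-memory dummy DB,
    # skipping the gpt-3.5-turbo round trip that generate_dummy_db performs.
    db_code = ["""
        CREATE TABLE singer (singer_id INTEGER PRIMARY KEY, name TEXT, age INTEGER);
        INSERT INTO singer VALUES (1, 'Ada', 30);
        INSERT INTO singer VALUES (2, 'Grace', 41);
    """]

    # Executes the schema, then the query; prints [(1,)] and returns True
    assert test_query_on_dummy_db(db_code, "SELECT count(*) FROM singer WHERE age > 35")
    # A broken query prints the sqlite3 error and returns False
    assert not test_query_on_dummy_db(db_code, "SELECT * FROM no_such_table")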
@@ -139,7 +196,6 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
 
     input_ids = tok(messages, return_tensors="pt").input_ids
     input_ids = input_ids.to(m.device)
-    #streamer = TextIteratorStreamer(tok, timeout=1000.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         input_ids=input_ids,
         max_new_tokens=max_new_tokens,
@@ -154,15 +210,6 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
         do_sample=do_sample,
     )
 
-    #stream_complete = Event()
-
-    # def generate_and_signal_complete():
-    #     m.generate(**generate_kwargs)
-    #     stream_complete.set()
-
-    # t1 = Thread(target=generate_and_signal_complete)
-    # t1.start()
-
    tokens = m.generate(**generate_kwargs)
 
     responses = []
@@ -172,14 +219,21 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
         # Only take what comes after ### Response:
         response_text = response_text.split("### Response:")[1].strip()
 
-        formatted_text = format(response_text) if format_sql else response_text
+        query = format(response_text) if format_sql else response_text
         if (num_return_sequences > 1):
-            formatted_text = formatted_text.replace("\n", " ").replace("\t", " ").strip()
-
-        responses.append(formatted_text)
+            query = query.replace("\n", " ").replace("\t", " ").strip()
+        # Test against dummy database
+        db_code = generate_dummy_db(db_info, input_message, query)
+        success = test_query_on_dummy_db(db_code, query)
+        if success:
+            responses.append(query)
+        else:
+            responses.append(query)
+
 
     # Concat responses to be a single string separated by a newline
-    output = "\n".join(responses)
+    #output = "\n".join(responses)
+    output = responses[0] if responses else ""
 
     if log:
         # Log the request to Firestore
@@ -219,7 +273,8 @@ with gr.Blocks(theme='gradio/soft') as demo:
         repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)
 
     with gr.Accordion("Generation strategies", open=False):
-        num_return_sequences = gr.Slider(label="Num Return Sequences", minimum=1, maximum=5, value=1, step=1)
+        md_description = gr.Markdown("""Increasing the number of return sequences generates more candidate SQL queries, but still yields only the best of those candidates; each candidate is tested against a dummy database built from the db info you provide.""")
+        num_return_sequences = gr.Slider(label="Number of return sequences (to generate and test)", minimum=1, maximum=5, value=1, step=1)
         num_beams = gr.Slider(label="Num Beams", minimum=1, maximum=5, value=1, step=1)
         do_sample = gr.Checkbox(label="Do Sample", value=False, interactive=True)
 
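Net effect in app.py: for each candidate SQL, generate() now asks gpt-3.5-turbo for a throwaway SQLite schema with dummy rows (generate_dummy_db), executes the candidate against it in memory (test_query_on_dummy_db), and returns the first collected candidate rather than joining them all. A hypothetical call outside the Gradio UI (assumes HF_MODEL_NAME and OpenAI credentials are configured for the Space; the db_info string is illustrative):

    # Sketch: exercising the updated generate() directly
    db_info = "Table concert, columns = [concert_id, singer_id, year]; Table singer, columns = [singer_id, name]"
    sql = generate(
        "Which singer performed in the most concerts?",
        db_info=db_info,
        num_return_sequences=3,  # candidates to generate and test on the dummy DB
        do_sample=True,          # sampling lets the candidates differ
    )
    print(sql)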
 
 
requirements.txt CHANGED
@@ -8,4 +8,5 @@ scipy
 transformers
 accelerate
 sqlparse
-firebase_admin
+firebase_admin
+openai