richardr1126 commited on
Commit
2153702
Β·
1 Parent(s): 3ff9987
Files changed (2) hide show
  1. app-ngrok.py +60 -11
  2. requirements.txt +2 -1
app-ngrok.py CHANGED
@@ -5,6 +5,7 @@ import requests
5
  from time import sleep
6
  import re
7
  import platform
 
8
  # Additional Firebase imports
9
  import firebase_admin
10
  from firebase_admin import credentials, firestore
@@ -23,6 +24,7 @@ initial_model = "WizardLM/WizardCoder-15B-V1.0"
23
  lora_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
24
  dataset = "richardr1126/spider-skeleton-context-instruct"
25
 
 
26
  # Initialize Firebase
27
  base64_string = os.getenv('FIREBASE')
28
  base64_bytes = base64_string.encode('utf-8')
@@ -36,18 +38,53 @@ cred = credentials.Certificate(firebase_auth)
36
  firebase_admin.initialize_app(cred)
37
  db = firestore.client()
38
 
39
- def log_to_firestore(input_message, db_info, temperature, response_text):
40
  doc_ref = db.collection('logs').document()
41
  log_data = {
42
  'timestamp': firestore.SERVER_TIMESTAMP,
43
  'temperature': temperature,
44
  'db_info': db_info,
45
  'input': input_message,
46
- 'output': response_text
47
  }
48
  doc_ref.set(log_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # End Firebase code
50
 
 
 
 
 
 
 
 
 
51
  def format(text):
52
  # Split the text by "|", and get the last element in the list which should be the final query
53
  try:
@@ -71,7 +108,7 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
71
  # Format the user's input message
72
  messages = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to sql: {input_message} {db_info}\n\n### Response:\n\n"
73
 
74
- url = "https://e9f4be879d38-8269039109365193683.ngrok-free.app/api/v1/generate"
75
  stop_sequence = stop_sequence.split(",")
76
  stop = ["###"] + stop_sequence
77
  payload = {
@@ -104,7 +141,7 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
104
 
105
  if log:
106
  # Log the request to Firestore
107
- log_to_firestore(input_message, db_info, temperature, output if format_sql else response_text)
108
 
109
  return output
110
 
@@ -120,16 +157,24 @@ with gr.Blocks(theme='gradio/soft') as demo:
120
  header = gr.HTML("""
121
  <h1 style="text-align: center">SQL Skeleton WizardCoder Demo</h1>
122
  <h3 style="text-align: center">πŸ•·οΈβ˜ οΈπŸ§™β€β™‚οΈ Generate SQL queries from Natural Language πŸ•·οΈβ˜ οΈπŸ§™β€β™‚οΈ</h3>
 
 
123
  """)
124
 
125
  output_box = gr.Code(label="Generated SQL", lines=2, interactive=True)
126
- note = gr.HTML("""<p style="font-size: 12px; text-align: center">⚠️ Should take 30-60s to generate</p>""")
 
 
 
 
 
127
  input_text = gr.Textbox(lines=3, placeholder='Write your question here...', label='NL Input')
128
  db_info = gr.Textbox(lines=4, placeholder='Example: | table_01 : column_01 , column_02 | table_02 : column_01 , column_02 | ...', label='Database Info')
129
  format_sql = gr.Checkbox(label="Format SQL + Remove Skeleton", value=True, interactive=True)
130
 
131
- # Generate button UI element
132
- run_button = gr.Button("Generate SQL", variant="primary")
 
133
 
134
  with gr.Accordion("Options", open=False):
135
  temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.2, step=0.1)
@@ -138,12 +183,11 @@ with gr.Blocks(theme='gradio/soft') as demo:
138
  repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)
139
  stop_sequence = gr.Textbox(lines=1, value="Explanation,Note", label='Extra Stop Sequence')
140
 
141
- ## Add statement saying that inputs/outpus are sent to firebase
142
  info = gr.HTML(f"""
143
  <p>🌐 Leveraging the <a href='https://huggingface.co/{quantized_model}'><strong>4-bit GGML version</strong></a> of <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a> model.</p>
144
  <p>πŸ”— How it's made: <a href='https://huggingface.co/{initial_model}'><strong>{initial_model}</strong></a> was finetuned to create <a href='https://huggingface.co/{lora_model}'><strong>{lora_model}</strong></a>, then merged together to create <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a>.</p>
145
  <p>πŸ“‰ Fine-tuning was performed using QLoRA techniques on the <a href='https://huggingface.co/datasets/{dataset}'><strong>{dataset}</strong></a> dataset. You can view training metrics on the <a href='https://huggingface.co/{lora_model}'><strong>QLoRa adapter HF Repo</strong></a>.</p>
146
- <p>πŸ“Š All inputs/outputs are logged to Firebase to see how the model is doing.</a></p>
147
  """)
148
 
149
  examples = gr.Examples([
@@ -172,11 +216,16 @@ with gr.Blocks(theme='gradio/soft') as demo:
172
  readme_content,
173
  )
174
 
175
- with gr.Accordion("More Options:", open=False):
176
  log = gr.Checkbox(label="Log to Firebase", value=True, interactive=False)
177
 
178
  # When the button is clicked, call the generate function, inputs are taken from the UI elements, outputs are sent to outputs elements
179
  run_button.click(fn=generate, inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty, format_sql, stop_sequence, log], outputs=output_box, api_name="txt2sql")
180
-
 
 
 
 
 
181
 
182
  demo.queue(concurrency_count=1, max_size=20).launch(debug=True)
 
5
  from time import sleep
6
  import re
7
  import platform
8
+ import pyperclip
9
  # Additional Firebase imports
10
  import firebase_admin
11
  from firebase_admin import credentials, firestore
 
24
  lora_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
25
  dataset = "richardr1126/spider-skeleton-context-instruct"
26
 
27
+ # Firebase code
28
  # Initialize Firebase
29
  base64_string = os.getenv('FIREBASE')
30
  base64_bytes = base64_string.encode('utf-8')
 
38
  firebase_admin.initialize_app(cred)
39
  db = firestore.client()
40
 
41
def log_message_to_firestore(input_message, db_info, temperature, response_text):
    """Record one generation event in the Firestore 'logs' collection.

    Writes the user's NL question, the database schema string, the sampling
    temperature, and the model's output, stamped with the server-side time.
    Relies on the module-level Firestore client ``db``.
    """
    entry = {
        'timestamp': firestore.SERVER_TIMESTAMP,
        'temperature': temperature,
        'db_info': db_info,
        'input': input_message,
        'output': response_text,
    }
    # Auto-generated document ID; create and write in one chained call.
    db.collection('logs').document().set(entry)
51
+
52
rated_outputs = set()  # remembers outputs already rated, to block duplicate votes

def log_rating_to_firestore(input_message, db_info, temperature, response_text, rating):
    """Store a thumbs-up/down rating for a generated SQL query in Firestore.

    Guards, in order:
      1. the same (input, db_info, output, temperature) combination may only
         be rated once per process lifetime (tracked in ``rated_outputs``);
      2. all of input, db_info, output, and rating must be non-empty.
    On success the rating is written to the 'ratings' collection and the user
    is thanked via a Gradio toast. Relies on the module-level ``db`` client.
    """
    global rated_outputs
    # Cheap composite key identifying one concrete generation.
    output_id = f"{input_message} {db_info} {response_text} {temperature}"

    if output_id in rated_outputs:
        gr.Warning("You've already rated this output!")
        return
    if not input_message or not db_info or not response_text or not rating:
        gr.Info("You haven't asked a question yet! Or the output box is empty.")
        return

    rated_outputs.add(output_id)

    db.collection('ratings').document().set({
        'timestamp': firestore.SERVER_TIMESTAMP,
        'temperature': temperature,
        'db_info': db_info,
        'input': input_message,
        'output': response_text,
        'rating': rating,
    })
    gr.Info("Thanks for your feedback!")
78
  # End Firebase code
79
 
80
def copy_to_clipboard(text):
    """Copy *text* to the clipboard, reporting the outcome via Gradio toasts.

    NOTE(review): pyperclip operates on the clipboard of the machine running
    this script — on a headless server (e.g. a hosted Space) it will usually
    fail and show the warning toast; confirm that is the intended behavior.
    """
    try:
        # Success toast stays inside the try so a failure anywhere
        # degrades to the warning instead of raising to the caller.
        pyperclip.copy(text)
        gr.Info("Copied to clipboard!")
    except Exception:
        gr.Warning("Couldn't copy to clipboard :(")
87
+
88
  def format(text):
89
  # Split the text by "|", and get the last element in the list which should be the final query
90
  try:
 
108
  # Format the user's input message
109
  messages = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to sql: {input_message} {db_info}\n\n### Response:\n\n"
110
 
111
+ url = os.getenv("KOBOLDCPP_API_URL")
112
  stop_sequence = stop_sequence.split(",")
113
  stop = ["###"] + stop_sequence
114
  payload = {
 
141
 
142
  if log:
143
  # Log the request to Firestore
144
+ log_message_to_firestore(input_message, db_info, temperature, output if format_sql else response_text)
145
 
146
  return output
147
 
 
157
  header = gr.HTML("""
158
  <h1 style="text-align: center">SQL Skeleton WizardCoder Demo</h1>
159
  <h3 style="text-align: center">πŸ•·οΈβ˜ οΈπŸ§™β€β™‚οΈ Generate SQL queries from Natural Language πŸ•·οΈβ˜ οΈπŸ§™β€β™‚οΈ</h3>
160
+ <br>
161
+ <p style="font-size: 12px; text-align: center">⚠️ Should take 30-60s to generate. Please rate the response, it helps a lot.</p>
162
  """)
163
 
164
  output_box = gr.Code(label="Generated SQL", lines=2, interactive=True)
165
+
166
+ with gr.Row():
167
+ copy_button = gr.Button("πŸ“‹ Copy SQL", variant="secondary")
168
+ rate_up = gr.Button("πŸ‘", variant="secondary")
169
+ rate_down = gr.Button("πŸ‘Ž", variant="secondary")
170
+
171
  input_text = gr.Textbox(lines=3, placeholder='Write your question here...', label='NL Input')
172
  db_info = gr.Textbox(lines=4, placeholder='Example: | table_01 : column_01 , column_02 | table_02 : column_01 , column_02 | ...', label='Database Info')
173
  format_sql = gr.Checkbox(label="Format SQL + Remove Skeleton", value=True, interactive=True)
174
 
175
+ with gr.Row():
176
+ run_button = gr.Button("Generate SQL", variant="primary")
177
+ clear_button = gr.ClearButton(variant="secondary")
178
 
179
  with gr.Accordion("Options", open=False):
180
  temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.2, step=0.1)
 
183
  repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)
184
  stop_sequence = gr.Textbox(lines=1, value="Explanation,Note", label='Extra Stop Sequence')
185
 
 
186
  info = gr.HTML(f"""
187
  <p>🌐 Leveraging the <a href='https://huggingface.co/{quantized_model}'><strong>4-bit GGML version</strong></a> of <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a> model.</p>
188
  <p>πŸ”— How it's made: <a href='https://huggingface.co/{initial_model}'><strong>{initial_model}</strong></a> was finetuned to create <a href='https://huggingface.co/{lora_model}'><strong>{lora_model}</strong></a>, then merged together to create <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a>.</p>
189
  <p>πŸ“‰ Fine-tuning was performed using QLoRA techniques on the <a href='https://huggingface.co/datasets/{dataset}'><strong>{dataset}</strong></a> dataset. You can view training metrics on the <a href='https://huggingface.co/{lora_model}'><strong>QLoRa adapter HF Repo</strong></a>.</p>
190
+ <p>πŸ“Š All inputs/outputs are logged to Firebase to see how the model is doing. You can also leave a rating for each generated SQL the model produces, which gets sent to the database as well.</a></p>
191
  """)
192
 
193
  examples = gr.Examples([
 
216
  readme_content,
217
  )
218
 
219
+ with gr.Accordion("Disabled Options:", open=False):
220
  log = gr.Checkbox(label="Log to Firebase", value=True, interactive=False)
221
 
222
  # When the button is clicked, call the generate function, inputs are taken from the UI elements, outputs are sent to outputs elements
223
  run_button.click(fn=generate, inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty, format_sql, stop_sequence, log], outputs=output_box, api_name="txt2sql")
224
+ copy_button.click(fn=copy_to_clipboard, inputs=[output_box])
225
+ clear_button.add([input_text, db_info, output_box])
226
+
227
+ # Firebase code - for rating the generated SQL (remove if you don't want to use Firebase)
228
+ rate_up.click(fn=log_rating_to_firestore, inputs=[input_text, db_info, temperature, output_box, rate_up])
229
+ rate_down.click(fn=log_rating_to_firestore, inputs=[input_text, db_info, temperature, output_box, rate_down])
230
 
231
  demo.queue(concurrency_count=1, max_size=20).launch(debug=True)
requirements.txt CHANGED
@@ -8,4 +8,5 @@ scipy
8
  transformers
9
  accelerate
10
  sqlparse
11
- firebase_admin
 
 
8
  transformers
9
  accelerate
10
  sqlparse
11
+ firebase_admin
12
+ pyperclip