Commit
Β·
2153702
1
Parent(s):
3ff9987
Update
Browse files- app-ngrok.py +60 -11
- requirements.txt +2 -1
app-ngrok.py
CHANGED
@@ -5,6 +5,7 @@ import requests
|
|
5 |
from time import sleep
|
6 |
import re
|
7 |
import platform
|
|
|
8 |
# Additional Firebase imports
|
9 |
import firebase_admin
|
10 |
from firebase_admin import credentials, firestore
|
@@ -23,6 +24,7 @@ initial_model = "WizardLM/WizardCoder-15B-V1.0"
|
|
23 |
lora_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
|
24 |
dataset = "richardr1126/spider-skeleton-context-instruct"
|
25 |
|
|
|
26 |
# Initialize Firebase
|
27 |
base64_string = os.getenv('FIREBASE')
|
28 |
base64_bytes = base64_string.encode('utf-8')
|
@@ -36,18 +38,53 @@ cred = credentials.Certificate(firebase_auth)
|
|
36 |
firebase_admin.initialize_app(cred)
|
37 |
db = firestore.client()
|
38 |
|
39 |
-
def
|
40 |
doc_ref = db.collection('logs').document()
|
41 |
log_data = {
|
42 |
'timestamp': firestore.SERVER_TIMESTAMP,
|
43 |
'temperature': temperature,
|
44 |
'db_info': db_info,
|
45 |
'input': input_message,
|
46 |
-
'output': response_text
|
47 |
}
|
48 |
doc_ref.set(log_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
# End Firebase code
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def format(text):
|
52 |
# Split the text by "|", and get the last element in the list which should be the final query
|
53 |
try:
|
@@ -71,7 +108,7 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
|
|
71 |
# Format the user's input message
|
72 |
messages = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to sql: {input_message} {db_info}\n\n### Response:\n\n"
|
73 |
|
74 |
-
url =
|
75 |
stop_sequence = stop_sequence.split(",")
|
76 |
stop = ["###"] + stop_sequence
|
77 |
payload = {
|
@@ -104,7 +141,7 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
|
|
104 |
|
105 |
if log:
|
106 |
# Log the request to Firestore
|
107 |
-
|
108 |
|
109 |
return output
|
110 |
|
@@ -120,16 +157,24 @@ with gr.Blocks(theme='gradio/soft') as demo:
|
|
120 |
header = gr.HTML("""
|
121 |
<h1 style="text-align: center">SQL Skeleton WizardCoder Demo</h1>
|
122 |
<h3 style="text-align: center">π·οΈβ οΈπ§ββοΈ Generate SQL queries from Natural Language π·οΈβ οΈπ§ββοΈ</h3>
|
|
|
|
|
123 |
""")
|
124 |
|
125 |
output_box = gr.Code(label="Generated SQL", lines=2, interactive=True)
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
127 |
input_text = gr.Textbox(lines=3, placeholder='Write your question here...', label='NL Input')
|
128 |
db_info = gr.Textbox(lines=4, placeholder='Example: | table_01 : column_01 , column_02 | table_02 : column_01 , column_02 | ...', label='Database Info')
|
129 |
format_sql = gr.Checkbox(label="Format SQL + Remove Skeleton", value=True, interactive=True)
|
130 |
|
131 |
-
|
132 |
-
|
|
|
133 |
|
134 |
with gr.Accordion("Options", open=False):
|
135 |
temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.2, step=0.1)
|
@@ -138,12 +183,11 @@ with gr.Blocks(theme='gradio/soft') as demo:
|
|
138 |
repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)
|
139 |
stop_sequence = gr.Textbox(lines=1, value="Explanation,Note", label='Extra Stop Sequence')
|
140 |
|
141 |
-
## Add statement saying that inputs/outpus are sent to firebase
|
142 |
info = gr.HTML(f"""
|
143 |
<p>π Leveraging the <a href='https://huggingface.co/{quantized_model}'><strong>4-bit GGML version</strong></a> of <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a> model.</p>
|
144 |
<p>π How it's made: <a href='https://huggingface.co/{initial_model}'><strong>{initial_model}</strong></a> was finetuned to create <a href='https://huggingface.co/{lora_model}'><strong>{lora_model}</strong></a>, then merged together to create <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a>.</p>
|
145 |
<p>π Fine-tuning was performed using QLoRA techniques on the <a href='https://huggingface.co/datasets/{dataset}'><strong>{dataset}</strong></a> dataset. You can view training metrics on the <a href='https://huggingface.co/{lora_model}'><strong>QLoRa adapter HF Repo</strong></a>.</p>
|
146 |
-
<p>π All inputs/outputs are logged to Firebase to see how the model is doing.</a></p>
|
147 |
""")
|
148 |
|
149 |
examples = gr.Examples([
|
@@ -172,11 +216,16 @@ with gr.Blocks(theme='gradio/soft') as demo:
|
|
172 |
readme_content,
|
173 |
)
|
174 |
|
175 |
-
with gr.Accordion("
|
176 |
log = gr.Checkbox(label="Log to Firebase", value=True, interactive=False)
|
177 |
|
178 |
# When the button is clicked, call the generate function, inputs are taken from the UI elements, outputs are sent to outputs elements
|
179 |
run_button.click(fn=generate, inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty, format_sql, stop_sequence, log], outputs=output_box, api_name="txt2sql")
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
demo.queue(concurrency_count=1, max_size=20).launch(debug=True)
|
|
|
5 |
from time import sleep
|
6 |
import re
|
7 |
import platform
|
8 |
+
import pyperclip
|
9 |
# Additional Firebase imports
|
10 |
import firebase_admin
|
11 |
from firebase_admin import credentials, firestore
|
|
|
24 |
lora_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
|
25 |
dataset = "richardr1126/spider-skeleton-context-instruct"
|
26 |
|
27 |
+
# Firebase code
|
28 |
# Initialize Firebase
|
29 |
base64_string = os.getenv('FIREBASE')
|
30 |
base64_bytes = base64_string.encode('utf-8')
|
|
|
38 |
firebase_admin.initialize_app(cred)
|
39 |
db = firestore.client()
|
40 |
|
41 |
+
def log_message_to_firestore(input_message, db_info, temperature, response_text):
|
42 |
doc_ref = db.collection('logs').document()
|
43 |
log_data = {
|
44 |
'timestamp': firestore.SERVER_TIMESTAMP,
|
45 |
'temperature': temperature,
|
46 |
'db_info': db_info,
|
47 |
'input': input_message,
|
48 |
+
'output': response_text,
|
49 |
}
|
50 |
doc_ref.set(log_data)
|
51 |
+
|
52 |
+
rated_outputs = set() # set to store already rated outputs
|
53 |
+
|
54 |
+
def log_rating_to_firestore(input_message, db_info, temperature, response_text, rating):
|
55 |
+
global rated_outputs
|
56 |
+
output_id = f"{input_message} {db_info} {response_text} {temperature}"
|
57 |
+
|
58 |
+
if output_id in rated_outputs:
|
59 |
+
gr.Warning("You've already rated this output!")
|
60 |
+
return
|
61 |
+
if not input_message or not db_info or not response_text or not rating:
|
62 |
+
gr.Info("You haven't asked a question yet! Or the output box is empty.")
|
63 |
+
return
|
64 |
+
|
65 |
+
rated_outputs.add(output_id)
|
66 |
+
|
67 |
+
doc_ref = db.collection('ratings').document()
|
68 |
+
log_data = {
|
69 |
+
'timestamp': firestore.SERVER_TIMESTAMP,
|
70 |
+
'temperature': temperature,
|
71 |
+
'db_info': db_info,
|
72 |
+
'input': input_message,
|
73 |
+
'output': response_text,
|
74 |
+
'rating': rating,
|
75 |
+
}
|
76 |
+
doc_ref.set(log_data)
|
77 |
+
gr.Info("Thanks for your feedback!")
|
78 |
# End Firebase code
|
79 |
|
80 |
+
def copy_to_clipboard(text):
|
81 |
+
# Copy to clipboard
|
82 |
+
try:
|
83 |
+
pyperclip.copy(text)
|
84 |
+
gr.Info("Copied to clipboard!")
|
85 |
+
except Exception:
|
86 |
+
gr.Warning("Couldn't copy to clipboard :(")
|
87 |
+
|
88 |
def format(text):
|
89 |
# Split the text by "|", and get the last element in the list which should be the final query
|
90 |
try:
|
|
|
108 |
# Format the user's input message
|
109 |
messages = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to sql: {input_message} {db_info}\n\n### Response:\n\n"
|
110 |
|
111 |
+
url = os.getenv("KOBOLDCPP_API_URL")
|
112 |
stop_sequence = stop_sequence.split(",")
|
113 |
stop = ["###"] + stop_sequence
|
114 |
payload = {
|
|
|
141 |
|
142 |
if log:
|
143 |
# Log the request to Firestore
|
144 |
+
log_message_to_firestore(input_message, db_info, temperature, output if format_sql else response_text)
|
145 |
|
146 |
return output
|
147 |
|
|
|
157 |
header = gr.HTML("""
|
158 |
<h1 style="text-align: center">SQL Skeleton WizardCoder Demo</h1>
|
159 |
<h3 style="text-align: center">π·οΈβ οΈπ§ββοΈ Generate SQL queries from Natural Language π·οΈβ οΈπ§ββοΈ</h3>
|
160 |
+
<br>
|
161 |
+
<p style="font-size: 12px; text-align: center">β οΈ Should take 30-60s to generate. Please rate the response, it helps a lot.</p>
|
162 |
""")
|
163 |
|
164 |
output_box = gr.Code(label="Generated SQL", lines=2, interactive=True)
|
165 |
+
|
166 |
+
with gr.Row():
|
167 |
+
copy_button = gr.Button("π Copy SQL", variant="secondary")
|
168 |
+
rate_up = gr.Button("π", variant="secondary")
|
169 |
+
rate_down = gr.Button("π", variant="secondary")
|
170 |
+
|
171 |
input_text = gr.Textbox(lines=3, placeholder='Write your question here...', label='NL Input')
|
172 |
db_info = gr.Textbox(lines=4, placeholder='Example: | table_01 : column_01 , column_02 | table_02 : column_01 , column_02 | ...', label='Database Info')
|
173 |
format_sql = gr.Checkbox(label="Format SQL + Remove Skeleton", value=True, interactive=True)
|
174 |
|
175 |
+
with gr.Row():
|
176 |
+
run_button = gr.Button("Generate SQL", variant="primary")
|
177 |
+
clear_button = gr.ClearButton(variant="secondary")
|
178 |
|
179 |
with gr.Accordion("Options", open=False):
|
180 |
temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.2, step=0.1)
|
|
|
183 |
repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)
|
184 |
stop_sequence = gr.Textbox(lines=1, value="Explanation,Note", label='Extra Stop Sequence')
|
185 |
|
|
|
186 |
info = gr.HTML(f"""
|
187 |
<p>π Leveraging the <a href='https://huggingface.co/{quantized_model}'><strong>4-bit GGML version</strong></a> of <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a> model.</p>
|
188 |
<p>π How it's made: <a href='https://huggingface.co/{initial_model}'><strong>{initial_model}</strong></a> was finetuned to create <a href='https://huggingface.co/{lora_model}'><strong>{lora_model}</strong></a>, then merged together to create <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a>.</p>
|
189 |
<p>π Fine-tuning was performed using QLoRA techniques on the <a href='https://huggingface.co/datasets/{dataset}'><strong>{dataset}</strong></a> dataset. You can view training metrics on the <a href='https://huggingface.co/{lora_model}'><strong>QLoRa adapter HF Repo</strong></a>.</p>
|
190 |
+
<p>π All inputs/outputs are logged to Firebase to see how the model is doing. You can also leave a rating for each generated SQL the model produces, which gets sent to the database as well.</a></p>
|
191 |
""")
|
192 |
|
193 |
examples = gr.Examples([
|
|
|
216 |
readme_content,
|
217 |
)
|
218 |
|
219 |
+
with gr.Accordion("Disabled Options:", open=False):
|
220 |
log = gr.Checkbox(label="Log to Firebase", value=True, interactive=False)
|
221 |
|
222 |
# When the button is clicked, call the generate function, inputs are taken from the UI elements, outputs are sent to outputs elements
|
223 |
run_button.click(fn=generate, inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty, format_sql, stop_sequence, log], outputs=output_box, api_name="txt2sql")
|
224 |
+
copy_button.click(fn=copy_to_clipboard, inputs=[output_box])
|
225 |
+
clear_button.add([input_text, db_info, output_box])
|
226 |
+
|
227 |
+
# Firebase code - for rating the generated SQL (remove if you don't want to use Firebase)
|
228 |
+
rate_up.click(fn=log_rating_to_firestore, inputs=[input_text, db_info, temperature, output_box, rate_up])
|
229 |
+
rate_down.click(fn=log_rating_to_firestore, inputs=[input_text, db_info, temperature, output_box, rate_down])
|
230 |
|
231 |
demo.queue(concurrency_count=1, max_size=20).launch(debug=True)
|
requirements.txt
CHANGED
@@ -8,4 +8,5 @@ scipy
|
|
8 |
transformers
|
9 |
accelerate
|
10 |
sqlparse
|
11 |
-
firebase_admin
|
|
|
|
8 |
transformers
|
9 |
accelerate
|
10 |
sqlparse
|
11 |
+
firebase_admin
|
12 |
+
pyperclip
|