pp542-0965
committed on
Commit
·
c52a50c
1
Parent(s):
d3cf308
Add gradio app
Browse files
app.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
from peft import PeftModel
|
8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
9 |
+
|
10 |
+
|
11 |
+
def load_model_tokenizer():
    """Build the inference model and its tokenizer.

    Loads the Qwen2.5-3B-Instruct base model, wraps it with the
    GRPO-trained SQL-reasoning PEFT adapter (frozen, inference only),
    and loads the matching tokenizer.

    Returns:
        A ``(model, tokenizer)`` pair ready for generation.
    """
    base_model_id = "Qwen/Qwen2.5-3B-Instruct"
    adapter_id = "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO"

    base = AutoModelForCausalLM.from_pretrained(base_model_id, max_length=2560)
    # Attach the fine-tuned adapter; we only serve inference, so keep it frozen.
    tuned = PeftModel.from_pretrained(base, adapter_id, is_trainable=False)

    tok = AutoTokenizer.from_pretrained(base_model_id, max_length=2560)

    return tuned, tok


# Load once at import time so every Gradio request reuses the same weights.
model, tokenizer = load_model_tokenizer()
21 |
+
|
22 |
+
|
23 |
+
def create_prompt(schemas, question):
    """Build the chat-format prompt (system turn + user turn) for the model.

    Args:
        schemas: CREATE TABLE statements, newline-separated.
        question: Natural-language question to be answered with SQL.

    Returns:
        A list of two ``{'role', 'content'}`` dicts suitable for
        ``tokenizer.apply_chat_template``.
    """
    # System turn: fixed instructions describing the <reason>/<answer> format.
    system_content = """\
You are an expert SQL Query Writer.
Given relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.
Once, you are confident in your reasoning, you will then start generating the SQL Query as the answer that accurately solves the given question leveraging some or all schemas.

Remember that you should place all your reasoning between <reason> and </reason> tags.
Also, you should provide your solution between <answer> and </answer> tags.

An example generation is as follows:
<reason>
This is a sample reasoning that solves the question based on the schema.
</reason>
<answer>
SELECT
COLUMN
FROM TABLE_NAME
WHERE
CONDITION
</answer>"""

    # User turn: interpolate the caller-provided schemas and question.
    user_content = f"""\
SCHEMAS:
---------------

{schemas}

---------------

QUESTION: "{question}"\
"""

    return [
        {'role': 'system', 'content': system_content},
        {'role': 'user', 'content': user_content},
    ]
63 |
+
|
64 |
+
|
65 |
+
def extract_answer(gen_output):
    """Pull the final SQL answer out of a model generation.

    Searches *gen_output* for the first (case-insensitive)
    ``<answer>...</answer>`` span, matching across newlines.

    Args:
        gen_output: Decoded generation text.

    Returns:
        The text between the tags, or ``None`` when no answer tags are found.
    """
    match = re.search(
        r"<answer>(.+?)</answer>",
        gen_output,
        re.MULTILINE | re.DOTALL | re.IGNORECASE,
    )
    return match.group(1) if match else None
|
78 |
+
|
79 |
+
|
80 |
+
def response(user_schemas, user_question):
    """Gradio callback: generate a reasoned SQL answer for the inputs.

    Args:
        user_schemas: CREATE TABLE statements entered in the UI.
        user_question: Natural-language question entered in the UI.

    Returns:
        The raw assistant generation followed by the extracted final
        answer (or an explicit notice when no <answer> block was produced).
    """
    user_prompt = create_prompt(user_schemas, user_question)

    inputs = tokenizer.apply_chat_template(user_prompt,
                                           tokenize=True,
                                           add_generation_prompt=True,
                                           return_dict=True,
                                           return_tensors="pt")

    # Inference only: disable autograd bookkeeping during generation.
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=1024)

    outputs = tokenizer.batch_decode(outputs)
    # Keep only the assistant turn of the decoded chat transcript.
    output = outputs[0].split("<|im_start|>assistant")[-1]

    final_answer = extract_answer(output)
    # Bug fix: extract_answer returns None when the model emits no
    # <answer>...</answer> block, and `str + None` below would raise
    # TypeError — fall back to an explicit notice instead of crashing.
    if final_answer is None:
        final_answer = "No <answer>...</answer> block found in the generation."

    return output + "\n\n" + "="*20 + "\n\nFinal Answer: \n" + final_answer
|
98 |
+
|
99 |
+
|
100 |
+
# User-facing help text shown under the app title.
# Fix: original read "schemas to to generate" (doubled word).
desc = """
Please use the "Table Schemas" field to provide the required schemas to generate the SQL Query for - separated by new lines.
Eg. CREATE TABLE demographic (
subject_id text,
admission_type text,
hadm_id text)

CREATE TABLE diagnoses (
subject_id text,
hadm_id text)

Finally, use the "Question" field to provide the relevant question to be answered based on the provided schemas.
Eg. How many patients whose admission type is emergency.
"""
|
114 |
+
|
115 |
+
# Wire up the UI: two free-text inputs (schemas + question), one text output.
schema_box = gr.Textbox(
    label="Table Schemas",
    placeholder="Expected to have CREATE TABLE statements with datatypes separated by new lines",
)
question_box = gr.Textbox(
    label="Question",
    placeholder="Eg. How many patients whose admission type is emergency",
)

demo = gr.Interface(
    fn=response,
    inputs=[schema_box, question_box],
    outputs=gr.Textbox(label="Generated SQL Query with reasoning"),
    title="SQL Query Generator trained with GRPO to elicit reasoning",
    description=desc,
)

demo.launch()
|