pp542-0965 committed
Commit c52a50c · 1 Parent(s): d3cf308

Add gradio app

Files changed (1)
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
+import re
+
+import torch
+
+import gradio as gr
+
+from peft import PeftModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def load_model_tokenizer():
+    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)
+    model = PeftModel.from_pretrained(model, "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO", is_trainable=False)
+
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)
+
+    return model, tokenizer
+
+
+model, tokenizer = load_model_tokenizer()
+
+
+def create_prompt(schemas, question):
+    prompt = [
+        {
+            'role': 'system',
+            'content': """\
+You are an expert SQL Query Writer.
+Given relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.
+Once you are confident in your reasoning, you will then start generating the SQL Query as the answer that accurately solves the given question leveraging some or all schemas.
+
+Remember that you should place all your reasoning between <reason> and </reason> tags.
+Also, you should provide your solution between <answer> and </answer> tags.
+
+An example generation is as follows:
+<reason>
+This is a sample reasoning that solves the question based on the schema.
+</reason>
+<answer>
+SELECT
+COLUMN
+FROM TABLE_NAME
+WHERE
+CONDITION
+</answer>"""
+        },
+        {
+            'role': 'user',
+            'content': f"""\
+SCHEMAS:
+---------------
+
+{schemas}
+
+---------------
+
+QUESTION: "{question}"\
+"""
+        }
+    ]
+
+    return prompt
+
+
+def extract_answer(gen_output):
+    answer_start_token = "<answer>"
+    answer_end_token = "</answer>"
+    answer_match_format = re.compile(rf"{answer_start_token}(.+?){answer_end_token}", flags=re.MULTILINE | re.DOTALL | re.IGNORECASE)
+
+    answer_match = answer_match_format.search(gen_output)
+
+    final_answer = None
+
+    if answer_match is not None:
+        final_answer = answer_match.group(1)
+
+    return final_answer
+
+
+def response(user_schemas, user_question):
+    user_prompt = create_prompt(user_schemas, user_question)
+
+    inputs = tokenizer.apply_chat_template(user_prompt,
+                                           tokenize=True,
+                                           add_generation_prompt=True,
+                                           return_dict=True,
+                                           return_tensors="pt")
+
+    with torch.inference_mode():
+        outputs = model.generate(**inputs, max_new_tokens=1024)
+
+    outputs = tokenizer.batch_decode(outputs)
+    output = outputs[0].split("<|im_start|>assistant")[-1]
+
+    final_answer = extract_answer(output)
+
+    return output + "\n\n" + "="*20 + "\n\nFinal Answer: \n" + (final_answer if final_answer is not None else "No <answer> tag found.")
+
+
+desc = """
+Please use the "Table Schemas" field to provide the required schemas to generate the SQL Query for, separated by new lines.
+Eg. CREATE TABLE demographic (
+        subject_id text,
+        admission_type text,
+        hadm_id text)
+
+    CREATE TABLE diagnoses (
+        subject_id text,
+        hadm_id text)
+
+Finally, use the "Question" field to provide the relevant question to be answered based on the provided schemas.
+Eg. How many patients whose admission type is emergency.
+"""
+
+demo = gr.Interface(
+    fn=response,
+    inputs=[gr.Textbox(label="Table Schemas",
+                       placeholder="Expected to have CREATE TABLE statements with datatypes separated by new lines"),
+            gr.Textbox(label="Question",
+                       placeholder="Eg. How many patients whose admission type is emergency")
+            ],
+    outputs=gr.Textbox(label="Generated SQL Query with reasoning"),
+    title="SQL Query Generator trained with GRPO to elicit reasoning",
+    description=desc
+)
+
+demo.launch()
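
A quick way to sanity-check the <answer>-tag extraction used above, without downloading the model, is to run the same regex against a hand-written generation string. The sample output below is hypothetical; only the pattern mirrors extract_answer() in app.py, and the sample query reuses the demographic schema from the app description.

    import re

    # Same effective pattern as extract_answer() in app.py:
    # capture everything between <answer> and </answer>.
    answer_match_format = re.compile(r"<answer>(.+?)</answer>",
                                     flags=re.MULTILINE | re.DOTALL | re.IGNORECASE)

    # Hypothetical model output in the format requested by the system prompt.
    sample_output = (
        "<reason>\nCount distinct patients whose admission_type is 'emergency'.\n</reason>\n"
        "<answer>\nSELECT COUNT(DISTINCT subject_id) FROM demographic "
        "WHERE admission_type = 'emergency'\n</answer>"
    )

    match = answer_match_format.search(sample_output)
    print(match.group(1).strip() if match else "No <answer> tag found.")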