simone-papicchio franceth commited on
Commit
af2b1fd
·
verified ·
1 Parent(s): c1258af

Fix prompts buttons, and NL2SQL bug (#24)

Browse files

- Fix prompts buttons, and NL2SQL bug (b8f53f4140ce72bf889c039fa072989834ee8d73)


Co-authored-by: Francesco Giannuzzo <[email protected]>

Files changed (3) hide show
  1. app.py +48 -63
  2. utilities.py +10 -4
  3. utils_get_db_tables_info.py +31 -3
app.py CHANGED
@@ -509,9 +509,13 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
509
  selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
510
  input_data['models'] = selected_models
511
  button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
512
- return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state), gr.update(interactive=button_state)
513
 
514
  # Add the Textbox to the interface
 
 
 
 
515
  prompt = gr.TextArea(
516
  label="Customise the prompt for selected models here or leave the default one.",
517
  placeholder=prompt_default,
@@ -522,17 +526,20 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
522
 
523
  # Submit button (initially disabled)
524
  with gr.Row():
525
- submit_models_button = gr.Button("Submit Models for NL2SQL task", interactive=False)
526
- submit_models_button_tqa = gr.Button("Submit Models for TQA task", interactive=False)
527
 
528
  def check_prompt(prompt):
529
  #TODO
530
  missing_elements = []
531
  if(prompt==""):
532
- input_data["prompt"] = prompt_default
 
 
 
 
533
  button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
534
  else:
535
- input_data["prompt"]=prompt
536
  if "{db_schema}" not in prompt:
537
  missing_elements.append("{db_schema}")
538
  if "{question}" not in prompt:
@@ -543,21 +550,21 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
543
  value=f"<div style='text-align: center; font-size: 18px; font-weight: bold;'>"
544
  f"❌ Missing {', '.join(missing_elements)} in the prompt ❌</div>",
545
  visible=True
546
- ), gr.update(interactive=button_state)
547
- return gr.update(visible=False), gr.update(interactive=button_state)
548
 
549
- prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, submit_models_button_tqa])
550
  # Link checkboxes to selection events
551
  for checkbox in model_checkboxes:
552
  checkbox.change(
553
  fn=get_selected_models,
554
  inputs=model_checkboxes,
555
- outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
556
  )
557
  prompt.change(
558
  fn=get_selected_models,
559
  inputs=model_checkboxes,
560
- outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
561
  )
562
 
563
  submit_models_button.click(
@@ -566,11 +573,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
566
  outputs=[selected_models_output, select_model_acc, qatch_acc]
567
  )
568
 
569
- submit_models_button_tqa.click(
570
- fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
571
- inputs=model_checkboxes,
572
- outputs=[selected_models_output, select_model_acc, qatch_acc]
573
- )
574
  def change_flag():
575
  global flag_TQA
576
  flag_TQA = True
@@ -579,8 +581,14 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
579
  global flag_TQA
580
  flag_TQA = False
581
 
582
- submit_models_button.click(fn = dis_flag, inputs=[], outputs=[])
583
- submit_models_button_tqa.click(fn = change_flag, inputs=[], outputs=[])
 
 
 
 
 
 
584
 
585
  def enable_disable(enable):
586
  return (
@@ -592,7 +600,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
592
  gr.update(interactive=enable),
593
  gr.update(interactive=enable),
594
  *[gr.update(interactive=enable) for _ in table_outputs],
595
- gr.update(interactive=enable),
596
  gr.update(interactive=enable)
597
  )
598
 
@@ -610,24 +617,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
610
  default_checkbox,
611
  table_selector,
612
  *table_outputs,
613
- open_model_selection,
614
- submit_models_button_tqa
615
- ]
616
- )
617
- submit_models_button_tqa.click(
618
- fn=enable_disable,
619
- inputs=[gr.State(False)],
620
- outputs=[
621
- *model_checkboxes,
622
- submit_models_button,
623
- preview_output,
624
- submit_button,
625
- file_input,
626
- default_checkbox,
627
- table_selector,
628
- *table_outputs,
629
- open_model_selection,
630
- submit_models_button_tqa
631
  ]
632
  )
633
 
@@ -645,8 +635,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
645
  default_checkbox,
646
  table_selector,
647
  *table_outputs,
648
- open_model_selection,
649
- submit_models_button_tqa
650
  ]
651
  )
652
 
@@ -749,13 +738,28 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
749
  yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
750
 
751
  else:
 
752
  orchestrator_generator = OrchestratorGenerator()
753
  target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
754
 
755
  #create target_df[target_answer]
756
  if flag_TQA :
757
- if (input_data["prompt"] == prompt_default):
758
- input_data["prompt"] = prompt_default_tqa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
759
  target_df = us.extract_answer(target_df)
760
 
761
  predictor = ModelPrediction()
@@ -766,6 +770,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
766
  count=0
767
  for index, row in target_df.iterrows():
768
  if (reset_flag == False):
 
769
  percent_complete = round(((index+1) / len(target_df)) * 100, 2)
770
  load_text = f"{generate_loading_text(percent_complete)}"
771
 
@@ -780,7 +785,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
780
  #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
781
  model_to_send = None if not flag_TQA else model
782
 
783
-
784
  db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
785
  db_id = input_data["db_name"],
786
  base_path = input_data["data_path"],
@@ -806,11 +810,11 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
806
  prompt=f"{prompt_to_send}",
807
  task=task
808
  )
 
809
  prediction = response['response_parsed']
810
  price = response['cost']
811
  answer = response['response']
812
 
813
- end_time = time.time()
814
  if flag_TQA:
815
  task_string = "Answer"
816
  else:
@@ -857,6 +861,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
857
  evaluator = OrchestratorEvaluator()
858
 
859
  for model in input_data["models"]:
 
860
  if not flag_TQA:
861
  metrics_df_model = evaluator.evaluate_df(
862
  df=predictions_dict[model],
@@ -920,11 +925,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
920
  inputs=[],
921
  outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
922
  )
923
- submit_models_button_tqa.click(
924
- change_tab,
925
- inputs=[],
926
- outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
927
- )
928
 
929
  selected_models_display = gr.JSON(label="Final input data", visible=False)
930
  metrics_df = gr.DataFrame(visible=False)
@@ -936,20 +936,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
936
  outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
937
  )
938
 
939
- submit_models_button_tqa.click(
940
- fn=qatch_flow_nl_sql,
941
- inputs=[],
942
- outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
943
- )
944
-
945
  submit_models_button.click(
946
  fn=lambda: gr.update(value=input_data),
947
  outputs=[selected_models_display]
948
  )
949
- submit_models_button_tqa.click(
950
- fn=lambda: gr.update(value=input_data),
951
- outputs=[selected_models_display]
952
- )
953
 
954
  # Works for METRICS
955
  metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
@@ -972,10 +962,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
972
  fn=lambda: gr.update(visible=False),
973
  outputs=[download_metrics]
974
  )
975
- submit_models_button_tqa.click(
976
- fn=lambda: gr.update(visible=False),
977
- outputs=[download_metrics]
978
- )
979
 
980
  def refresh():
981
  global reset_flag
@@ -1007,8 +993,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1007
  default_checkbox,
1008
  table_selector,
1009
  *table_outputs,
1010
- open_model_selection,
1011
- submit_models_button_tqa
1012
  ]
1013
  )
1014
 
 
509
  selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
510
  input_data['models'] = selected_models
511
  button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
512
+ return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
513
 
514
  # Add the Textbox to the interface
515
+ with gr.Row():
516
+ button_prompt_nlsql = gr.Button("Choose NL2SQL task")
517
+ button_prompt_tqa = gr.Button("Choose TQA task")
518
+
519
  prompt = gr.TextArea(
520
  label="Customise the prompt for selected models here or leave the default one.",
521
  placeholder=prompt_default,
 
526
 
527
  # Submit button (initially disabled)
528
  with gr.Row():
529
+ submit_models_button = gr.Button("Submit Models", interactive=False)
 
530
 
531
  def check_prompt(prompt):
532
  #TODO
533
  missing_elements = []
534
  if(prompt==""):
535
+ global flag_TQA
536
+ if not flag_TQA:
537
+ input_data["prompt"] = prompt_default
538
+ else:
539
+ input_data["prompt"] = prompt_default_tqa
540
  button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
541
  else:
542
+ input_data["prompt"] = prompt
543
  if "{db_schema}" not in prompt:
544
  missing_elements.append("{db_schema}")
545
  if "{question}" not in prompt:
 
550
  value=f"<div style='text-align: center; font-size: 18px; font-weight: bold;'>"
551
  f"❌ Missing {', '.join(missing_elements)} in the prompt ❌</div>",
552
  visible=True
553
+ ), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
554
+ return gr.update(visible=False), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
555
 
556
+ prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
557
  # Link checkboxes to selection events
558
  for checkbox in model_checkboxes:
559
  checkbox.change(
560
  fn=get_selected_models,
561
  inputs=model_checkboxes,
562
+ outputs=[selected_models_output, select_model_acc, submit_models_button]
563
  )
564
  prompt.change(
565
  fn=get_selected_models,
566
  inputs=model_checkboxes,
567
+ outputs=[selected_models_output, select_model_acc, submit_models_button]
568
  )
569
 
570
  submit_models_button.click(
 
573
  outputs=[selected_models_output, select_model_acc, qatch_acc]
574
  )
575
 
 
 
 
 
 
576
  def change_flag():
577
  global flag_TQA
578
  flag_TQA = True
 
581
  global flag_TQA
582
  flag_TQA = False
583
 
584
+ button_prompt_tqa.click(fn = change_flag, inputs=[], outputs=[])
585
+
586
+ button_prompt_nlsql.click(fn = dis_flag, inputs=[], outputs=[])
587
+
588
+ button_prompt_tqa.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
589
+
590
+ button_prompt_nlsql.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
591
+
592
 
593
  def enable_disable(enable):
594
  return (
 
600
  gr.update(interactive=enable),
601
  gr.update(interactive=enable),
602
  *[gr.update(interactive=enable) for _ in table_outputs],
 
603
  gr.update(interactive=enable)
604
  )
605
 
 
617
  default_checkbox,
618
  table_selector,
619
  *table_outputs,
620
+ open_model_selection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  ]
622
  )
623
 
 
635
  default_checkbox,
636
  table_selector,
637
  *table_outputs,
638
+ open_model_selection
 
639
  ]
640
  )
641
 
 
738
  yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
739
 
740
  else:
741
+ global flag_TQA
742
  orchestrator_generator = OrchestratorGenerator()
743
  target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
744
 
745
  #create target_df[target_answer]
746
  if flag_TQA :
747
+ # if (input_data["prompt"] == prompt_default):
748
+ # input_data["prompt"] = prompt_default_tqa
749
+
750
+ target_df['db_schema'] = target_df.apply(
751
+ lambda row: utils_get_db_tables_info.utils_extract_db_schema_as_string(
752
+ db_id=input_data["db_name"],
753
+ base_path=input_data["data_path"],
754
+ normalize=False,
755
+ sql=row["query"],
756
+ get_insert_into=True,
757
+ model=None,
758
+ prompt=input_data["prompt"].format(question=row["question"], db_schema="")
759
+ ),
760
+ axis=1
761
+ )
762
+
763
  target_df = us.extract_answer(target_df)
764
 
765
  predictor = ModelPrediction()
 
770
  count=0
771
  for index, row in target_df.iterrows():
772
  if (reset_flag == False):
773
+ global flag_TQA
774
  percent_complete = round(((index+1) / len(target_df)) * 100, 2)
775
  load_text = f"{generate_loading_text(percent_complete)}"
776
 
 
785
  #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
786
  model_to_send = None if not flag_TQA else model
787
 
 
788
  db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
789
  db_id = input_data["db_name"],
790
  base_path = input_data["data_path"],
 
810
  prompt=f"{prompt_to_send}",
811
  task=task
812
  )
813
+ end_time = time.time()
814
  prediction = response['response_parsed']
815
  price = response['cost']
816
  answer = response['response']
817
 
 
818
  if flag_TQA:
819
  task_string = "Answer"
820
  else:
 
861
  evaluator = OrchestratorEvaluator()
862
 
863
  for model in input_data["models"]:
864
+ global flag_TQA
865
  if not flag_TQA:
866
  metrics_df_model = evaluator.evaluate_df(
867
  df=predictions_dict[model],
 
925
  inputs=[],
926
  outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
927
  )
 
 
 
 
 
928
 
929
  selected_models_display = gr.JSON(label="Final input data", visible=False)
930
  metrics_df = gr.DataFrame(visible=False)
 
936
  outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
937
  )
938
 
 
 
 
 
 
 
939
  submit_models_button.click(
940
  fn=lambda: gr.update(value=input_data),
941
  outputs=[selected_models_display]
942
  )
 
 
 
 
943
 
944
  # Works for METRICS
945
  metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
 
962
  fn=lambda: gr.update(visible=False),
963
  outputs=[download_metrics]
964
  )
 
 
 
 
965
 
966
  def refresh():
967
  global reset_flag
 
993
  default_checkbox,
994
  table_selector,
995
  *table_outputs,
996
+ open_model_selection
 
997
  ]
998
  )
999
 
utilities.py CHANGED
@@ -8,6 +8,7 @@ import os
8
  from qatch.connectors.sqlite_connector import SqliteConnector
9
  from qatch.evaluate_dataset.metrics_evaluators import CellPrecision, CellRecall, ExecutionAccuracy, TupleCardinality, TupleConstraint, TupleOrder, ValidEfficiencyScore
10
  import qatch.evaluate_dataset.orchestrator_evaluator as eva
 
11
  #import tiktoken
12
  from transformers import AutoTokenizer
13
 
@@ -151,11 +152,16 @@ def extract_answer(df):
151
  answers = []
152
  for _, row in df.iterrows():
153
  query = row["query"]
154
- db_path = row["db_path"]
155
- try:
156
- conn = SqliteConnector(relative_db_path = db_path , db_name= "db")
157
- answer = eva._utils_run_query_if_str(query, conn)
 
 
 
 
158
  answers.append(answer)
 
159
  except Exception as e:
160
  answers.append(f"Error: {e}")
161
 
 
8
  from qatch.connectors.sqlite_connector import SqliteConnector
9
  from qatch.evaluate_dataset.metrics_evaluators import CellPrecision, CellRecall, ExecutionAccuracy, TupleCardinality, TupleConstraint, TupleOrder, ValidEfficiencyScore
10
  import qatch.evaluate_dataset.orchestrator_evaluator as eva
11
+ import utils_get_db_tables_info
12
  #import tiktoken
13
  from transformers import AutoTokenizer
14
 
 
152
  answers = []
153
  for _, row in df.iterrows():
154
  query = row["query"]
155
+ db_schema = row["db_schema"]
156
+ #db_path = row["db_path"]
157
+ try:
158
+ conn = utils_get_db_tables_info.create_db_temp(db_schema)
159
+
160
+ result = pd.read_sql_query(query, conn)
161
+ answer = result.values.tolist() # Convert the DataFrame to a list of lists
162
+
163
  answers.append(answer)
164
+ conn.close()
165
  except Exception as e:
166
  answers.append(f"Error: {e}")
167
 
utils_get_db_tables_info.py CHANGED
@@ -49,11 +49,15 @@ def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | No
49
  tables = [tbl[0] for tbl in cursor.fetchall()]
50
 
51
  for table in tables:
 
52
  # Retrieve the CREATE TABLE statement for each table
53
  cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}' AND sql IS NOT NULL;")
54
  create_table_stmt = cursor.fetchone()
55
  if create_table_stmt:
56
- entries.append(create_table_stmt[0])
 
 
 
57
 
58
  if get_insert_into:
59
  # Retrieve all data from the table
@@ -70,9 +74,10 @@ def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | No
70
  for row in rows[:max_len]:
71
  values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
72
  insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
73
- entries.append(insert_stmt)
74
 
75
- if model != None : entries = us.crop_entries_per_token(entries, model, prompt)
 
76
 
77
  return entries
78
 
@@ -112,3 +117,26 @@ def _combine_schema_entries(schema_entries, normalize):
112
  )
113
  for entry in schema_entries
114
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  tables = [tbl[0] for tbl in cursor.fetchall()]
50
 
51
  for table in tables:
52
+ entries_per_table = []
53
  # Retrieve the CREATE TABLE statement for each table
54
  cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}' AND sql IS NOT NULL;")
55
  create_table_stmt = cursor.fetchone()
56
  if create_table_stmt:
57
+ stmt = create_table_stmt[0].strip()
58
+ if not stmt.endswith(';'):
59
+ stmt += ';'
60
+ entries_per_table.append(stmt)
61
 
62
  if get_insert_into:
63
  # Retrieve all data from the table
 
74
  for row in rows[:max_len]:
75
  values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
76
  insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
77
+ entries_per_table.append(insert_stmt)
78
 
79
+ if model != None : entries_per_table = us.crop_entries_per_token(entries_per_table, model, prompt)
80
+ entries.extend(entries_per_table)
81
 
82
  return entries
83
 
 
117
  )
118
  for entry in schema_entries
119
  )
120
+
121
+
122
+ def create_db_temp(schema_sql: str) -> sqlite3.Connection:
123
+ """
124
+ Creates a temporary SQLite database in memory by executing the provided SQL schema.
125
+
126
+ Args:
127
+ schema_sql (str): The SQL code containing CREATE TABLE and INSERT INTO.
128
+
129
+ Returns:
130
+ sqlite3.Connection: Connection object to the temporary database.
131
+ """
132
+ conn = sqlite3.connect(':memory:')
133
+ cursor = conn.cursor()
134
+
135
+ try:
136
+ cursor.executescript(schema_sql)
137
+ conn.commit()
138
+ except sqlite3.Error as e:
139
+ conn.close()
140
+ raise
141
+
142
+ return conn