Spaces:
Sleeping
Sleeping
Fix prompts buttons, and NL2SQL bug (#24)
Browse files- Fix prompts buttons, and NL2SQL bug (b8f53f4140ce72bf889c039fa072989834ee8d73)
Co-authored-by: Francesco Giannuzzo <[email protected]>
- app.py +48 -63
- utilities.py +10 -4
- utils_get_db_tables_info.py +31 -3
app.py
CHANGED
@@ -509,9 +509,13 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
509 |
selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
|
510 |
input_data['models'] = selected_models
|
511 |
button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
|
512 |
-
return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
|
513 |
|
514 |
# Add the Textbox to the interface
|
|
|
|
|
|
|
|
|
515 |
prompt = gr.TextArea(
|
516 |
label="Customise the prompt for selected models here or leave the default one.",
|
517 |
placeholder=prompt_default,
|
@@ -522,17 +526,20 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
522 |
|
523 |
# Submit button (initially disabled)
|
524 |
with gr.Row():
|
525 |
-
submit_models_button = gr.Button("Submit Models
|
526 |
-
submit_models_button_tqa = gr.Button("Submit Models for TQA task", interactive=False)
|
527 |
|
528 |
def check_prompt(prompt):
|
529 |
#TODO
|
530 |
missing_elements = []
|
531 |
if(prompt==""):
|
532 |
-
|
|
|
|
|
|
|
|
|
533 |
button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
|
534 |
else:
|
535 |
-
input_data["prompt"]=prompt
|
536 |
if "{db_schema}" not in prompt:
|
537 |
missing_elements.append("{db_schema}")
|
538 |
if "{question}" not in prompt:
|
@@ -543,21 +550,21 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
543 |
value=f"<div style='text-align: center; font-size: 18px; font-weight: bold;'>"
|
544 |
f"❌ Missing {', '.join(missing_elements)} in the prompt ❌</div>",
|
545 |
visible=True
|
546 |
-
), gr.update(interactive=button_state)
|
547 |
-
return gr.update(visible=False),
|
548 |
|
549 |
-
prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button
|
550 |
# Link checkboxes to selection events
|
551 |
for checkbox in model_checkboxes:
|
552 |
checkbox.change(
|
553 |
fn=get_selected_models,
|
554 |
inputs=model_checkboxes,
|
555 |
-
outputs=[selected_models_output, select_model_acc, submit_models_button
|
556 |
)
|
557 |
prompt.change(
|
558 |
fn=get_selected_models,
|
559 |
inputs=model_checkboxes,
|
560 |
-
outputs=[selected_models_output, select_model_acc, submit_models_button
|
561 |
)
|
562 |
|
563 |
submit_models_button.click(
|
@@ -566,11 +573,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
566 |
outputs=[selected_models_output, select_model_acc, qatch_acc]
|
567 |
)
|
568 |
|
569 |
-
submit_models_button_tqa.click(
|
570 |
-
fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
|
571 |
-
inputs=model_checkboxes,
|
572 |
-
outputs=[selected_models_output, select_model_acc, qatch_acc]
|
573 |
-
)
|
574 |
def change_flag():
|
575 |
global flag_TQA
|
576 |
flag_TQA = True
|
@@ -579,8 +581,14 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
579 |
global flag_TQA
|
580 |
flag_TQA = False
|
581 |
|
582 |
-
|
583 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
584 |
|
585 |
def enable_disable(enable):
|
586 |
return (
|
@@ -592,7 +600,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
592 |
gr.update(interactive=enable),
|
593 |
gr.update(interactive=enable),
|
594 |
*[gr.update(interactive=enable) for _ in table_outputs],
|
595 |
-
gr.update(interactive=enable),
|
596 |
gr.update(interactive=enable)
|
597 |
)
|
598 |
|
@@ -610,24 +617,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
610 |
default_checkbox,
|
611 |
table_selector,
|
612 |
*table_outputs,
|
613 |
-
open_model_selection
|
614 |
-
submit_models_button_tqa
|
615 |
-
]
|
616 |
-
)
|
617 |
-
submit_models_button_tqa.click(
|
618 |
-
fn=enable_disable,
|
619 |
-
inputs=[gr.State(False)],
|
620 |
-
outputs=[
|
621 |
-
*model_checkboxes,
|
622 |
-
submit_models_button,
|
623 |
-
preview_output,
|
624 |
-
submit_button,
|
625 |
-
file_input,
|
626 |
-
default_checkbox,
|
627 |
-
table_selector,
|
628 |
-
*table_outputs,
|
629 |
-
open_model_selection,
|
630 |
-
submit_models_button_tqa
|
631 |
]
|
632 |
)
|
633 |
|
@@ -645,8 +635,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
645 |
default_checkbox,
|
646 |
table_selector,
|
647 |
*table_outputs,
|
648 |
-
open_model_selection
|
649 |
-
submit_models_button_tqa
|
650 |
]
|
651 |
)
|
652 |
|
@@ -749,13 +738,28 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
749 |
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
750 |
|
751 |
else:
|
|
|
752 |
orchestrator_generator = OrchestratorGenerator()
|
753 |
target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
|
754 |
|
755 |
#create target_df[target_answer]
|
756 |
if flag_TQA :
|
757 |
-
if (input_data["prompt"] == prompt_default):
|
758 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
759 |
target_df = us.extract_answer(target_df)
|
760 |
|
761 |
predictor = ModelPrediction()
|
@@ -766,6 +770,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
766 |
count=0
|
767 |
for index, row in target_df.iterrows():
|
768 |
if (reset_flag == False):
|
|
|
769 |
percent_complete = round(((index+1) / len(target_df)) * 100, 2)
|
770 |
load_text = f"{generate_loading_text(percent_complete)}"
|
771 |
|
@@ -780,7 +785,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
780 |
#samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
|
781 |
model_to_send = None if not flag_TQA else model
|
782 |
|
783 |
-
|
784 |
db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
|
785 |
db_id = input_data["db_name"],
|
786 |
base_path = input_data["data_path"],
|
@@ -806,11 +810,11 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
806 |
prompt=f"{prompt_to_send}",
|
807 |
task=task
|
808 |
)
|
|
|
809 |
prediction = response['response_parsed']
|
810 |
price = response['cost']
|
811 |
answer = response['response']
|
812 |
|
813 |
-
end_time = time.time()
|
814 |
if flag_TQA:
|
815 |
task_string = "Answer"
|
816 |
else:
|
@@ -857,6 +861,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
857 |
evaluator = OrchestratorEvaluator()
|
858 |
|
859 |
for model in input_data["models"]:
|
|
|
860 |
if not flag_TQA:
|
861 |
metrics_df_model = evaluator.evaluate_df(
|
862 |
df=predictions_dict[model],
|
@@ -920,11 +925,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
920 |
inputs=[],
|
921 |
outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
|
922 |
)
|
923 |
-
submit_models_button_tqa.click(
|
924 |
-
change_tab,
|
925 |
-
inputs=[],
|
926 |
-
outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
|
927 |
-
)
|
928 |
|
929 |
selected_models_display = gr.JSON(label="Final input data", visible=False)
|
930 |
metrics_df = gr.DataFrame(visible=False)
|
@@ -936,20 +936,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
936 |
outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
|
937 |
)
|
938 |
|
939 |
-
submit_models_button_tqa.click(
|
940 |
-
fn=qatch_flow_nl_sql,
|
941 |
-
inputs=[],
|
942 |
-
outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
|
943 |
-
)
|
944 |
-
|
945 |
submit_models_button.click(
|
946 |
fn=lambda: gr.update(value=input_data),
|
947 |
outputs=[selected_models_display]
|
948 |
)
|
949 |
-
submit_models_button_tqa.click(
|
950 |
-
fn=lambda: gr.update(value=input_data),
|
951 |
-
outputs=[selected_models_display]
|
952 |
-
)
|
953 |
|
954 |
# Works for METRICS
|
955 |
metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
|
@@ -972,10 +962,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
972 |
fn=lambda: gr.update(visible=False),
|
973 |
outputs=[download_metrics]
|
974 |
)
|
975 |
-
submit_models_button_tqa.click(
|
976 |
-
fn=lambda: gr.update(visible=False),
|
977 |
-
outputs=[download_metrics]
|
978 |
-
)
|
979 |
|
980 |
def refresh():
|
981 |
global reset_flag
|
@@ -1007,8 +993,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1007 |
default_checkbox,
|
1008 |
table_selector,
|
1009 |
*table_outputs,
|
1010 |
-
open_model_selection
|
1011 |
-
submit_models_button_tqa
|
1012 |
]
|
1013 |
)
|
1014 |
|
|
|
509 |
selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
|
510 |
input_data['models'] = selected_models
|
511 |
button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
|
512 |
+
return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
|
513 |
|
514 |
# Add the Textbox to the interface
|
515 |
+
with gr.Row():
|
516 |
+
button_prompt_nlsql = gr.Button("Choose NL2SQL task")
|
517 |
+
button_prompt_tqa = gr.Button("Choose TQA task")
|
518 |
+
|
519 |
prompt = gr.TextArea(
|
520 |
label="Customise the prompt for selected models here or leave the default one.",
|
521 |
placeholder=prompt_default,
|
|
|
526 |
|
527 |
# Submit button (initially disabled)
|
528 |
with gr.Row():
|
529 |
+
submit_models_button = gr.Button("Submit Models", interactive=False)
|
|
|
530 |
|
531 |
def check_prompt(prompt):
|
532 |
#TODO
|
533 |
missing_elements = []
|
534 |
if(prompt==""):
|
535 |
+
global flag_TQA
|
536 |
+
if not flag_TQA:
|
537 |
+
input_data["prompt"] = prompt_default
|
538 |
+
else:
|
539 |
+
input_data["prompt"] = prompt_default_tqa
|
540 |
button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
|
541 |
else:
|
542 |
+
input_data["prompt"] = prompt
|
543 |
if "{db_schema}" not in prompt:
|
544 |
missing_elements.append("{db_schema}")
|
545 |
if "{question}" not in prompt:
|
|
|
550 |
value=f"<div style='text-align: center; font-size: 18px; font-weight: bold;'>"
|
551 |
f"❌ Missing {', '.join(missing_elements)} in the prompt ❌</div>",
|
552 |
visible=True
|
553 |
+
), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
|
554 |
+
return gr.update(visible=False), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
|
555 |
|
556 |
+
prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
|
557 |
# Link checkboxes to selection events
|
558 |
for checkbox in model_checkboxes:
|
559 |
checkbox.change(
|
560 |
fn=get_selected_models,
|
561 |
inputs=model_checkboxes,
|
562 |
+
outputs=[selected_models_output, select_model_acc, submit_models_button]
|
563 |
)
|
564 |
prompt.change(
|
565 |
fn=get_selected_models,
|
566 |
inputs=model_checkboxes,
|
567 |
+
outputs=[selected_models_output, select_model_acc, submit_models_button]
|
568 |
)
|
569 |
|
570 |
submit_models_button.click(
|
|
|
573 |
outputs=[selected_models_output, select_model_acc, qatch_acc]
|
574 |
)
|
575 |
|
|
|
|
|
|
|
|
|
|
|
576 |
def change_flag():
|
577 |
global flag_TQA
|
578 |
flag_TQA = True
|
|
|
581 |
global flag_TQA
|
582 |
flag_TQA = False
|
583 |
|
584 |
+
button_prompt_tqa.click(fn = change_flag, inputs=[], outputs=[])
|
585 |
+
|
586 |
+
button_prompt_nlsql.click(fn = dis_flag, inputs=[], outputs=[])
|
587 |
+
|
588 |
+
button_prompt_tqa.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
|
589 |
+
|
590 |
+
button_prompt_nlsql.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
|
591 |
+
|
592 |
|
593 |
def enable_disable(enable):
|
594 |
return (
|
|
|
600 |
gr.update(interactive=enable),
|
601 |
gr.update(interactive=enable),
|
602 |
*[gr.update(interactive=enable) for _ in table_outputs],
|
|
|
603 |
gr.update(interactive=enable)
|
604 |
)
|
605 |
|
|
|
617 |
default_checkbox,
|
618 |
table_selector,
|
619 |
*table_outputs,
|
620 |
+
open_model_selection
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
621 |
]
|
622 |
)
|
623 |
|
|
|
635 |
default_checkbox,
|
636 |
table_selector,
|
637 |
*table_outputs,
|
638 |
+
open_model_selection
|
|
|
639 |
]
|
640 |
)
|
641 |
|
|
|
738 |
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
739 |
|
740 |
else:
|
741 |
+
global flag_TQA
|
742 |
orchestrator_generator = OrchestratorGenerator()
|
743 |
target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
|
744 |
|
745 |
#create target_df[target_answer]
|
746 |
if flag_TQA :
|
747 |
+
# if (input_data["prompt"] == prompt_default):
|
748 |
+
# input_data["prompt"] = prompt_default_tqa
|
749 |
+
|
750 |
+
target_df['db_schema'] = target_df.apply(
|
751 |
+
lambda row: utils_get_db_tables_info.utils_extract_db_schema_as_string(
|
752 |
+
db_id=input_data["db_name"],
|
753 |
+
base_path=input_data["data_path"],
|
754 |
+
normalize=False,
|
755 |
+
sql=row["query"],
|
756 |
+
get_insert_into=True,
|
757 |
+
model=None,
|
758 |
+
prompt=input_data["prompt"].format(question=row["question"], db_schema="")
|
759 |
+
),
|
760 |
+
axis=1
|
761 |
+
)
|
762 |
+
|
763 |
target_df = us.extract_answer(target_df)
|
764 |
|
765 |
predictor = ModelPrediction()
|
|
|
770 |
count=0
|
771 |
for index, row in target_df.iterrows():
|
772 |
if (reset_flag == False):
|
773 |
+
global flag_TQA
|
774 |
percent_complete = round(((index+1) / len(target_df)) * 100, 2)
|
775 |
load_text = f"{generate_loading_text(percent_complete)}"
|
776 |
|
|
|
785 |
#samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
|
786 |
model_to_send = None if not flag_TQA else model
|
787 |
|
|
|
788 |
db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
|
789 |
db_id = input_data["db_name"],
|
790 |
base_path = input_data["data_path"],
|
|
|
810 |
prompt=f"{prompt_to_send}",
|
811 |
task=task
|
812 |
)
|
813 |
+
end_time = time.time()
|
814 |
prediction = response['response_parsed']
|
815 |
price = response['cost']
|
816 |
answer = response['response']
|
817 |
|
|
|
818 |
if flag_TQA:
|
819 |
task_string = "Answer"
|
820 |
else:
|
|
|
861 |
evaluator = OrchestratorEvaluator()
|
862 |
|
863 |
for model in input_data["models"]:
|
864 |
+
global flag_TQA
|
865 |
if not flag_TQA:
|
866 |
metrics_df_model = evaluator.evaluate_df(
|
867 |
df=predictions_dict[model],
|
|
|
925 |
inputs=[],
|
926 |
outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
|
927 |
)
|
|
|
|
|
|
|
|
|
|
|
928 |
|
929 |
selected_models_display = gr.JSON(label="Final input data", visible=False)
|
930 |
metrics_df = gr.DataFrame(visible=False)
|
|
|
936 |
outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
|
937 |
)
|
938 |
|
|
|
|
|
|
|
|
|
|
|
|
|
939 |
submit_models_button.click(
|
940 |
fn=lambda: gr.update(value=input_data),
|
941 |
outputs=[selected_models_display]
|
942 |
)
|
|
|
|
|
|
|
|
|
943 |
|
944 |
# Works for METRICS
|
945 |
metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
|
|
|
962 |
fn=lambda: gr.update(visible=False),
|
963 |
outputs=[download_metrics]
|
964 |
)
|
|
|
|
|
|
|
|
|
965 |
|
966 |
def refresh():
|
967 |
global reset_flag
|
|
|
993 |
default_checkbox,
|
994 |
table_selector,
|
995 |
*table_outputs,
|
996 |
+
open_model_selection
|
|
|
997 |
]
|
998 |
)
|
999 |
|
utilities.py
CHANGED
@@ -8,6 +8,7 @@ import os
|
|
8 |
from qatch.connectors.sqlite_connector import SqliteConnector
|
9 |
from qatch.evaluate_dataset.metrics_evaluators import CellPrecision, CellRecall, ExecutionAccuracy, TupleCardinality, TupleConstraint, TupleOrder, ValidEfficiencyScore
|
10 |
import qatch.evaluate_dataset.orchestrator_evaluator as eva
|
|
|
11 |
#import tiktoken
|
12 |
from transformers import AutoTokenizer
|
13 |
|
@@ -151,11 +152,16 @@ def extract_answer(df):
|
|
151 |
answers = []
|
152 |
for _, row in df.iterrows():
|
153 |
query = row["query"]
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
|
|
158 |
answers.append(answer)
|
|
|
159 |
except Exception as e:
|
160 |
answers.append(f"Error: {e}")
|
161 |
|
|
|
8 |
from qatch.connectors.sqlite_connector import SqliteConnector
|
9 |
from qatch.evaluate_dataset.metrics_evaluators import CellPrecision, CellRecall, ExecutionAccuracy, TupleCardinality, TupleConstraint, TupleOrder, ValidEfficiencyScore
|
10 |
import qatch.evaluate_dataset.orchestrator_evaluator as eva
|
11 |
+
import utils_get_db_tables_info
|
12 |
#import tiktoken
|
13 |
from transformers import AutoTokenizer
|
14 |
|
|
|
152 |
answers = []
|
153 |
for _, row in df.iterrows():
|
154 |
query = row["query"]
|
155 |
+
db_schema = row["db_schema"]
|
156 |
+
#db_path = row["db_path"]
|
157 |
+
try:
|
158 |
+
conn = utils_get_db_tables_info.create_db_temp(db_schema)
|
159 |
+
|
160 |
+
result = pd.read_sql_query(query, conn)
|
161 |
+
answer = result.values.tolist() # Convert the DataFrame to a list of lists
|
162 |
+
|
163 |
answers.append(answer)
|
164 |
+
conn.close()
|
165 |
except Exception as e:
|
166 |
answers.append(f"Error: {e}")
|
167 |
|
utils_get_db_tables_info.py
CHANGED
@@ -49,11 +49,15 @@ def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | No
|
|
49 |
tables = [tbl[0] for tbl in cursor.fetchall()]
|
50 |
|
51 |
for table in tables:
|
|
|
52 |
# Retrieve the CREATE TABLE statement for each table
|
53 |
cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}' AND sql IS NOT NULL;")
|
54 |
create_table_stmt = cursor.fetchone()
|
55 |
if create_table_stmt:
|
56 |
-
|
|
|
|
|
|
|
57 |
|
58 |
if get_insert_into:
|
59 |
# Retrieve all data from the table
|
@@ -70,9 +74,10 @@ def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | No
|
|
70 |
for row in rows[:max_len]:
|
71 |
values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
|
72 |
insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
|
73 |
-
|
74 |
|
75 |
-
|
|
|
76 |
|
77 |
return entries
|
78 |
|
@@ -112,3 +117,26 @@ def _combine_schema_entries(schema_entries, normalize):
|
|
112 |
)
|
113 |
for entry in schema_entries
|
114 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
tables = [tbl[0] for tbl in cursor.fetchall()]
|
50 |
|
51 |
for table in tables:
|
52 |
+
entries_per_table = []
|
53 |
# Retrieve the CREATE TABLE statement for each table
|
54 |
cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}' AND sql IS NOT NULL;")
|
55 |
create_table_stmt = cursor.fetchone()
|
56 |
if create_table_stmt:
|
57 |
+
stmt = create_table_stmt[0].strip()
|
58 |
+
if not stmt.endswith(';'):
|
59 |
+
stmt += ';'
|
60 |
+
entries_per_table.append(stmt)
|
61 |
|
62 |
if get_insert_into:
|
63 |
# Retrieve all data from the table
|
|
|
74 |
for row in rows[:max_len]:
|
75 |
values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
|
76 |
insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
|
77 |
+
entries_per_table.append(insert_stmt)
|
78 |
|
79 |
+
if model != None : entries_per_table = us.crop_entries_per_token(entries_per_table, model, prompt)
|
80 |
+
entries.extend(entries_per_table)
|
81 |
|
82 |
return entries
|
83 |
|
|
|
117 |
)
|
118 |
for entry in schema_entries
|
119 |
)
|
120 |
+
|
121 |
+
|
122 |
+
def create_db_temp(schema_sql: str) -> sqlite3.Connection:
    """
    Create a temporary in-memory SQLite database from an SQL script.

    Args:
        schema_sql (str): The SQL code to execute, typically the CREATE TABLE
            and INSERT INTO statements describing the database.

    Returns:
        sqlite3.Connection: Open connection to the populated in-memory
            database. The caller is responsible for closing it.

    Raises:
        sqlite3.Error: If the script cannot be executed. The connection is
            closed before the exception propagates.
    """
    conn = sqlite3.connect(':memory:')

    try:
        # Connection.executescript runs the whole multi-statement script
        # without needing a separately managed cursor.
        conn.executescript(schema_sql)
        conn.commit()
    except Exception:
        # Close on any failure (not only sqlite3.Error) so a bad script or a
        # bad argument type never leaks the connection, then re-raise the
        # original exception unchanged.
        conn.close()
        raise

    return conn
|