added prompt template and OpenRouter API key input
Files changed:
- app.py (+100 -20)
- duckdb-nsql/eval/constants.py (+3 -1)
- duckdb-nsql/eval/prompt_formatters.py (+29 -0)
- evaluation_logic.py (+14 -6)
app.py CHANGED

@@ -1,40 +1,120 @@
 import gradio as gr
+import os
+from evaluation_logic import run_evaluation
+from eval.predict import PROMPT_FORMATTERS

+PROMPT_TEMPLATES = {
+    "duckdbinstgraniteshort": PROMPT_FORMATTERS["duckdbinstgraniteshort"]().PROMPT_TEMPLATE,
+    "duckdbinst": PROMPT_FORMATTERS["duckdbinst"]().PROMPT_TEMPLATE,
+}
+
+def gradio_run_evaluation(inference_api, model_name, prompt_format, openrouter_token=None, custom_prompt=None):
+    # Set environment variable if OpenRouter token is provided
+    if inference_api == "openrouter":
+        os.environ["OPENROUTER_API_KEY"] = str(openrouter_token)
+
+    # We now pass both the format name and content to evaluation
     output = []
-    for result in run_evaluation(inference_api, str(model_name).strip(), prompt_format):
+    for result in run_evaluation(inference_api, str(model_name).strip(), prompt_format, custom_prompt):
         output.append(result)
         yield "\n".join(output)

+def update_token_visibility(api):
+    """Update visibility of the OpenRouter token input"""
+    return gr.update(visible=api == "openrouter")
+
+def update_prompt_template(prompt_format):
+    """Update the template content when a preset is selected"""
+    if prompt_format in PROMPT_TEMPLATES:
+        return PROMPT_FORMATTERS[prompt_format]()
+    return ""
+
+def handle_template_edit(prompt_format, new_template):
+    """Handle when user edits the template"""
+    # If the template matches a preset exactly, keep the preset name
+    for format_name, template in PROMPT_TEMPLATES.items():
+        if template.strip() == new_template.strip():
+            return format_name
+    # Otherwise switch to custom
+    return "custom"
+
 with gr.Blocks(gr.themes.Soft()) as demo:
     gr.Markdown("# DuckDB SQL Evaluation App")

+    with gr.Row():
+        with gr.Column():
+            inference_api = gr.Dropdown(
+                label="Inference API",
+                choices=['openrouter'],
+                value="openrouter"
+            )
+
+            openrouter_token = gr.Textbox(
+                label="OpenRouter API Token",
+                placeholder="Enter your OpenRouter API token",
+                type="password",
+                visible=True
+            )
+
+            model_name = gr.Textbox(
+                label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)"
+            )
+
+            gr.Markdown("[View OpenRouter Models](https://openrouter.ai/models?order=top-weekly)")
+
+    with gr.Row():
+        with gr.Column():
-    prompt_format = gr.Dropdown(
-        label="Prompt Format",
-        choices=['duckdbinst', 'duckdbinstgraniteshort'], #AVAILABLE_PROMPT_FORMATS,
-        value="duckdbinstgraniteshort"
-    )
+            # Add 'custom' to the choices
+            prompt_format = gr.Dropdown(
+                label="Prompt Format",
+                choices=['duckdbinst', 'duckdbinstgraniteshort', 'custom'],
+                value="duckdbinstgraniteshort"
+            )
+
+            custom_prompt = gr.TextArea(
+                label="Prompt Template Content",
+                placeholder="Enter your custom prompt template here or select a preset format above.",
+                lines=10,
+                value=PROMPT_TEMPLATES['duckdbinstgraniteshort']  # Set initial value
+            )

     gr.Examples(
         examples=[
-            ["openrouter", "qwen/qwen-2.5-72b-instruct", "duckdbinst"],
-            ["openrouter", "meta-llama/llama-3.2-3b-instruct:free", "duckdbinstgraniteshort"],
-            ["openrouter", "mistralai/mistral-nemo", "duckdbinst"],
+            ["openrouter", "qwen/qwen-2.5-72b-instruct", "duckdbinst", "", PROMPT_TEMPLATES['duckdbinst']],
+            ["openrouter", "meta-llama/llama-3.2-3b-instruct:free", "duckdbinstgraniteshort", "", PROMPT_TEMPLATES['duckdbinstgraniteshort']],
+            ["openrouter", "mistralai/mistral-nemo", "duckdbinst", "", PROMPT_TEMPLATES['duckdbinst']],
         ],
-        inputs=[inference_api, model_name, prompt_format],
+        inputs=[inference_api, model_name, prompt_format, openrouter_token, custom_prompt],
     )

     start_btn = gr.Button("Start Evaluation")
     output = gr.Textbox(label="Output", lines=20)

+    # Update token visibility
+    inference_api.change(
+        fn=update_token_visibility,
+        inputs=[inference_api],
+        outputs=[openrouter_token]
+    )
+
+    # Update template content when preset is selected
+    prompt_format.change(
+        fn=update_prompt_template,
+        inputs=[prompt_format],
+        outputs=[custom_prompt]
+    )
+
+    # Update format dropdown when template is edited
+    custom_prompt.change(
+        fn=handle_template_edit,
+        inputs=[prompt_format, custom_prompt],
+        outputs=[prompt_format]
+    )
+
+    start_btn.click(
+        fn=gradio_run_evaluation,
+        inputs=[inference_api, model_name, prompt_format, openrouter_token, custom_prompt],
+        outputs=output
+    )

 demo.queue().launch()
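The three .change() handlers wire the Prompt Format dropdown and the template textbox together: selecting a preset fills the textbox, and editing the textbox switches the dropdown to "custom" unless the text still matches a preset exactly. Below is a minimal, standalone sketch of that round-trip logic, with invented stand-in templates instead of the real PROMPT_FORMATTERS registry and with the template string returned directly rather than a formatter instance.

# Standalone sketch of the preset/custom sync logic (stand-in templates only).
PROMPT_TEMPLATES = {
    "duckdbinst": "### Schema:\n{schema}\n### Question:\n{question}\n### SQL:",
    "duckdbinstgraniteshort": "Schema: {schema}\nQuestion: {question}\nSQL:",
}

def update_prompt_template(prompt_format: str) -> str:
    """Return the template text to show when a preset is selected."""
    return PROMPT_TEMPLATES.get(prompt_format, "")

def handle_template_edit(prompt_format: str, new_template: str) -> str:
    """Keep the preset name if the edited text still matches a preset, else 'custom'.

    prompt_format is accepted only to mirror the Gradio input signature.
    """
    for format_name, template in PROMPT_TEMPLATES.items():
        if template.strip() == new_template.strip():
            return format_name
    return "custom"

if __name__ == "__main__":
    text = update_prompt_template("duckdbinst")
    assert handle_template_edit("duckdbinst", text) == "duckdbinst"
    assert handle_template_edit("duckdbinst", text + "\n-- tweaked") == "custom"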
duckdb-nsql/eval/constants.py CHANGED

@@ -16,6 +16,7 @@ from prompt_formatters import (
     DuckDBInstFormatterGPTmini,
     DuckDBInstFormatterPhiAzure,
     DuckDBInstFormatterLlamaSyntax,
+    DuckDBInstFormatterCustom,
 )

 PROMPT_FORMATTERS = {

@@ -33,5 +34,6 @@ PROMPT_FORMATTERS = {
     "duckdbinstgptmini": DuckDBInstFormatterPhi,
     "duckdbinstphiazure": DuckDBInstFormatterPhiAzure,
     "duckdbinstllamabasic": DuckDBInstFormatterLlamaBasic,
-    "duckdbinstllamasyntax": DuckDBInstFormatterLlamaSyntax
+    "duckdbinstllamasyntax": DuckDBInstFormatterLlamaSyntax,
+    "custom": DuckDBInstFormatterCustom
 }
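Registering DuckDBInstFormatterCustom under the "custom" key lets the rest of the pipeline resolve it by name like any preset. A hypothetical sketch of that lookup pattern, using stand-in classes rather than the real duckdb-nsql imports (get_formatter is an illustrative helper, not part of the repo):

# Hypothetical registry lookup, mirroring the shape of PROMPT_FORMATTERS.
class DuckDBInstFormatterCustomSketch:
    PROMPT_TEMPLATE = ""  # filled in at run time with the user's template

PROMPT_FORMATTERS = {"custom": DuckDBInstFormatterCustomSketch}

def get_formatter(prompt_format: str, custom_prompt: str | None = None):
    formatter = PROMPT_FORMATTERS[prompt_format]()
    if prompt_format == "custom" and custom_prompt is not None:
        formatter.PROMPT_TEMPLATE = custom_prompt  # per-instance override
    return formatter

fmt = get_formatter("custom", "Schema: {schema}\nQuestion: {question}\nSQL:")
print(fmt.PROMPT_TEMPLATE)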
duckdb-nsql/eval/prompt_formatters.py CHANGED

@@ -958,6 +958,35 @@ Write a DuckDB SQL query for the given question!
         )
         return instruction

+
+class DuckDBInstFormatterCustom(RajkumarFormatter):
+    """DuckDB Inst class."""
+
+    PROMPT_TEMPLATE = ""
+
+    @classmethod
+    def format_retrieved_context(
+        cls,
+        context: list[str],
+    ) -> str:
+        """Format retrieved context."""
+        context_str = "\n--------\n".join(context)
+        return f"\n### Documentation:\n{context_str}\n"
+
+    @classmethod
+    def format_prompt(
+        cls,
+        instruction: str,
+        table_text: str,
+        context_text: str,
+    ) -> str | list[str]:
+        """Get prompt format."""
+        instruction = cls.PROMPT_TEMPLATE.format(
+            schema=table_text,
+            question=instruction
+        )
+        return instruction
+
 class DuckDBInstNoShorthandFormatter(DuckDBInstFormatter):
     """DuckDB Inst class."""
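DuckDBInstFormatterCustom leaves PROMPT_TEMPLATE empty and fills it at run time, substituting only {schema} and {question} via str.format. A standalone sketch of that substitution with an invented template, schema, and question:

# Standalone sketch of the substitution format_prompt performs
# (class trimmed to the relevant part; inputs are invented examples).
class CustomFormatterSketch:
    PROMPT_TEMPLATE = (
        "### Schema\n{schema}\n\n"
        "### Question\n{question}\n\n"
        "### DuckDB SQL\n"
    )

    @classmethod
    def format_prompt(cls, instruction: str, table_text: str, context_text: str) -> str:
        # Only {schema} and {question} are substituted; context_text is ignored,
        # mirroring the committed implementation above.
        return cls.PROMPT_TEMPLATE.format(schema=table_text, question=instruction)

print(CustomFormatterSketch.format_prompt(
    instruction="How many rides were longer than 10 km?",
    table_text="CREATE TABLE rides (id INTEGER, distance_km DOUBLE);",
    context_text="",
))

Because the substitution uses str.format, a custom template must escape any literal braces as {{ and }}.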
evaluation_logic.py CHANGED

@@ -54,7 +54,7 @@ def save_prediction(inference_api, model_name, prompt_format, question, generate
         "timestamp": datetime.now().isoformat()
     }, f)

-def save_evaluation(inference_api, model_name, prompt_format, metrics):
+def save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics):
     evaluation_file = evaluation_folder / f"evaluation_{file_uuid}.json"
     evaluation_folder.mkdir(parents=True, exist_ok=True)

@@ -64,6 +64,7 @@ def save_evaluation(inference_api, model_name, prompt_format, metrics):
         "inference_api": inference_api,
         "model_name": model_name,
         "prompt_format": prompt_format,
+        "custom_prompt": str(custom_prompt),
         "timestamp": datetime.now().isoformat()
     }

@@ -82,7 +83,7 @@ def save_evaluation(inference_api, model_name, prompt_format, metrics):
         json.dump(flattened_metrics, f)
         f.write('\n')

-def run_prediction(inference_api, model_name, prompt_format, output_file):
+def run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
     dataset_path = str(eval_dir / "data/dev.json")
     table_meta_path = str(eval_dir / "data/tables.json")
     stop_tokens = [';']

@@ -100,7 +101,11 @@ def run_prediction(inference_api, model_name, prompt_format, output_file):
     try:
         # Initialize necessary components
         data_formatter = DefaultLoader()
+        if prompt_format.startswith("custom"):
+            prompt_formatter = PROMPT_FORMATTERS["custom"]()
+            prompt_formatter.PROMPT_TEMPLATE = custom_prompt
+        else:
+            prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

         # Load manifest
         manifest = get_manifest(

@@ -159,7 +164,7 @@ def run_prediction(inference_api, model_name, prompt_format, output_file):
         yield f"Prediction failed with error: {str(e)}"
         yield f"Error traceback: {traceback.format_exc()}"

-def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort"):
+def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort", custom_prompt=None):
     if "OPENROUTER_API_KEY" not in os.environ:
         yield "Error: OPENROUTER_API_KEY not found in environment variables."
         return

@@ -176,6 +181,9 @@ def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgranitesh
     yield f"Using model: {model_name}"
     yield f"Using prompt format: {prompt_format}"

+    if prompt_format == "custom":
+        prompt_format = prompt_format + "_" + str(abs(hash(custom_prompt)) % (10 ** 8))
+
     output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"

     # Ensure the output directory exists

@@ -186,7 +194,7 @@ def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgranitesh
         yield "Skipping prediction step and proceeding to evaluation."
     else:
         # Run prediction
-        for output in run_prediction(inference_api, model_name, prompt_format, output_file):
+        for output in run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
             yield output

     # Run evaluation

@@ -226,7 +234,7 @@ def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgranitesh
     )

     # Save evaluation results to dataset
-    save_evaluation(inference_api, model_name, prompt_format, metrics)
+    save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics)

     yield "Evaluation completed."
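When the "custom" format is selected, run_evaluation derives a distinct output file name by suffixing the format with a short hash of the template, so different custom templates do not share a results file. A standalone sketch of that naming scheme (output_file_name is an illustrative helper; the model name and template are invented). Note that Python salts str hashes per process, so the suffix is only stable within a single run unless PYTHONHASHSEED is fixed.

# Sketch of the custom-prompt output-file naming used above.
from datetime import datetime

def output_file_name(prompt_format: str, model_name: str, custom_prompt: str | None) -> str:
    # Custom runs get a short hash suffix derived from the template text.
    if prompt_format == "custom":
        prompt_format = f"{prompt_format}_{abs(hash(custom_prompt)) % 10**8}"
    date = datetime.now().strftime("%y-%m-%d")
    return f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{date}.json"

print(output_file_name("custom", "qwen/qwen-2.5-72b-instruct", "Schema: {schema}\nQ: {question}"))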