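"""Predict and evaluate text-to-SQL queries with the duckdb-nsql eval harness.

Predictions are generated through a Manifest client pointed at OpenRouter and
written out as JSON lines; evaluation reuses the harness's test-suite metrics
(execution accuracy, exact match, equality, edit distance).
"""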
import os
import sys
from pathlib import Path
from datetime import datetime
import json
import traceback

# Add the necessary directories to the Python path
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])

# Import necessary functions and classes
from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
from eval.schema import TextToSQLParams, Table

AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())
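# AVAILABLE_PROMPT_FORMATS is not used directly below; it lists every prompt
# format the harness registers, so callers can check which `prompt_format`
# values are valid.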


def run_prediction(model_name, prompt_format, output_file):
    dataset_path = str(eval_dir / "data/dev.json")
    table_meta_path = str(eval_dir / "data/tables.json")
    stop_tokens = [';']
    max_tokens = 30000
    temperature = 0.1
    num_beams = -1
    manifest_client = "openrouter"
    manifest_engine = model_name
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False

    yield "Starting prediction..."
    try:
        # Initialize necessary components
        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        # Load data
        data = data_formatter.load_data(dataset_path)
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)

        # Prepare input for generate_sql
        text_to_sql_inputs = []
        for input_question in data:
            question = input_question["question"]
            db_id = input_question.get("db_id", "none")
            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []
            if len(table_params) == 0:
                yield f"[red] WARNING: No tables found for {db_id} [/red]"

            # Convert Table objects to dictionaries if they're already instantiated
            processed_table_params = []
            for table in table_params:
                if isinstance(table, Table):
                    processed_table_params.append(table.dict())
                else:
                    processed_table_params.append(table)

            text_to_sql_inputs.append(TextToSQLParams(
                instruction=question,
                database=db_id,
                tables=processed_table_params,
            ))

        # Generate SQL
        generated_sqls = generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],  # Assuming no retrieved docs
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel
        )

        # Save results
        with output_file.open('w') as f:
            for original_data, (sql, _) in zip(data, generated_sqls):
                output = {**original_data, "pred": sql}
                json.dump(output, f)
                f.write('\n')

        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"


def run_evaluation(model_name, prompt_format="duckdbinstgraniteshort"):
    if "OPENROUTER_API_KEY" not in os.environ:
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return

    try:
        # Set up the arguments
        dataset_path = str(eval_dir / "data/dev.json")
        table_meta_path = str(eval_dir / "data/tables.json")
        output_dir = eval_dir / "output"

        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"

        output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"

        # Ensure the output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            # Run prediction
            for output in run_prediction(model_name, prompt_format, output_file):
                yield output

        # Run evaluation
        yield "Starting evaluation..."

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(l) for l in output_file.open("r").readlines()]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )
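
        # `metrics` (as used below) is keyed by metric family: 'exec' holds
        # per-category counts and accuracies, while 'equality' and
        # 'edit_distance' hold aggregate scores.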
| yield "Evaluation completed." | |
| if metrics: | |
| yield "Overall Results:" | |
| overall_metrics = metrics['exec']['all'] | |
| yield f"Count: {overall_metrics['count']}" | |
| yield f"Execution Accuracy: {overall_metrics['exec']:.3f}" | |
| yield f"Exact Match Accuracy: {overall_metrics['exact']:.3f}" | |
| yield f"Equality: {metrics['equality']['equality']:.3f}" | |
| yield f"Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}" | |
| yield "\nResults by Category:" | |
| categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all'] | |
| for category in categories: | |
| if category in metrics['exec']: | |
| yield f"\n{category}:" | |
| category_metrics = metrics['exec'][category] | |
| yield f"Count: {category_metrics['count']}" | |
| yield f"Execution Accuracy: {category_metrics['exec']:.3f}" | |
| else: | |
| yield f"\n{category}: No data available" | |
| else: | |
| yield "No evaluation metrics returned." | |
| except Exception as e: | |
| yield f"An unexpected error occurred: {str(e)}" | |
| yield f"Error traceback: {traceback.format_exc()}" | |
| if __name__ == "__main__": | |
| model_name = input("Enter the model name: ") | |
| prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort" | |
| for result in run_evaluation(model_name, prompt_format): | |
| print(result, flush=True) |
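

# Example session (the model id below is only an illustration; pass any engine
# id your OpenRouter account can access, export OPENROUTER_API_KEY first, and
# keep the duckdb-nsql checkout next to this script):
#
#   $ python run_eval.py        # assuming this file is saved as run_eval.py
#   Enter the model name: openai/gpt-4o-mini
#   Enter the prompt format (default is duckdbinstgraniteshort):
#   Using model: openai/gpt-4o-mini
#   Using prompt format: duckdbinstgraniteshort
#   Starting prediction...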