"""Run DuckDB-NSQL text-to-SQL prediction and evaluation.

Predictions and evaluation scores are appended to local JSON Lines files that
``huggingface_hub.CommitScheduler`` instances push to the
``sql-console/duckdb-nsql-predictions`` and ``sql-console/duckdb-nsql-scores``
dataset repos.
"""
import os
import sys
import json
import hashlib
import traceback
import uuid
from pathlib import Path
from datetime import datetime

from huggingface_hub import CommitScheduler

current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
# Make the bundled duckdb-nsql checkout importable (provides the ``eval`` package).
sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])

from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql  # noqa: E402
from eval.evaluate import evaluate, compute_metrics, get_to_print  # noqa: E402
from eval.evaluate import test_suite_evaluation, read_tables_json  # noqa: E402
from eval.schema import TextToSQLParams, Table  # noqa: E402

AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())

# Dataset files shared by prediction and evaluation.
DEV_DATASET_PATH = eval_dir / "data/dev.json"
TABLES_METADATA_PATH = eval_dir / "data/tables.json"

# Categories reported (and flattened into columns) for execution accuracy.
EVAL_CATEGORIES = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']

prediction_folder = Path("prediction_results/")
evaluation_folder = Path("evaluation_results/")

# Generated once per process; each process appends to its own uniquely named files.
file_uuid = uuid.uuid4()

prediction_scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-predictions",
    repo_type="dataset",
    folder_path=prediction_folder,
    path_in_repo="data",
    every=10,
)

evaluation_scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-scores",
    repo_type="dataset",
    folder_path=evaluation_folder,
    path_in_repo="data",
    every=10,
)


def save_prediction(inference_api, model_name, prompt_format, question, generated_sql):
    """Append one prediction record as a JSON line to this process's prediction file.

    The file lives in ``prediction_folder``, which ``prediction_scheduler``
    commits to the predictions dataset repo.
    """
    prediction_file = prediction_folder / f"prediction_{file_uuid}.json"
    prediction_folder.mkdir(parents=True, exist_ok=True)
    record = {
        "inference_api": inference_api,
        "model_name": model_name,
        "prompt_format": prompt_format,
        "question": question,
        "generated_sql": generated_sql,
        "timestamp": datetime.now().isoformat(),
    }
    # Hold the scheduler's lock so a scheduled commit does not run mid-write.
    with prediction_scheduler.lock:
        with prediction_file.open("a") as f:
            json.dump(record, f)
            # One record per line; without the newline, records would be
            # concatenated into a single invalid JSON document.
            f.write('\n')


def save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics):
    """Append one flattened row of evaluation metrics to this process's evaluation file.

    For every name in ``EVAL_CATEGORIES`` the row gets ``<category>_count`` and
    ``<category>_execution_accuracy`` columns, taken from
    ``metrics['exec'][category]['count']`` / ``['exec']``. Categories absent from
    ``metrics['exec']`` are recorded as count 0 and accuracy 0.0.
    ``custom_prompt`` is stored only when ``prompt_format`` starts with "custom".
    """
    evaluation_file = evaluation_folder / f"evaluation_{file_uuid}.json"
    evaluation_folder.mkdir(parents=True, exist_ok=True)

    flattened_metrics = {
        "inference_api": inference_api,
        "model_name": model_name,
        "prompt_format": prompt_format,
        "custom_prompt": str(custom_prompt) if prompt_format.startswith("custom") else "",
        "timestamp": datetime.now().isoformat(),
    }

    exec_metrics = metrics['exec']
    for category in EVAL_CATEGORIES:
        if category in exec_metrics:
            category_metrics = exec_metrics[category]
            flattened_metrics[f"{category}_count"] = category_metrics['count']
            flattened_metrics[f"{category}_execution_accuracy"] = category_metrics['exec']
        else:
            flattened_metrics[f"{category}_count"] = 0
            flattened_metrics[f"{category}_execution_accuracy"] = 0.0

    with evaluation_scheduler.lock:
        with evaluation_file.open("a") as f:
            json.dump(flattened_metrics, f)
            f.write('\n')


def _stable_prompt_id(custom_prompt):
    """Return a process-independent integer in [0, 10**8) identifying *custom_prompt*."""
    # Builtin hash() of a str is salted per process (PYTHONHASHSEED), so it
    # cannot be used for file names that must match across runs.
    digest = hashlib.sha256(str(custom_prompt).encode("utf-8")).hexdigest()
    return int(digest, 16) % (10 ** 8)


def _read_jsonl(path):
    """Return the JSON objects stored one per line in *path*, skipping blank lines."""
    with Path(path).open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
    """Generate SQL for every dev-set question and write the results to *output_file*.

    This is a generator that yields human-readable progress/error strings; it
    reports failures as yielded messages instead of raising. ``output_file`` is
    only created on success. The output is written to a temporary sibling file
    and then renamed into place, so a failed run never leaves a partial file
    behind.

    Each output line is the original dev.json record plus a ``"pred"`` key with
    the generated SQL. Every prediction is also appended to the predictions
    dataset via ``save_prediction``.

    Args:
        inference_api: Passed to ``get_manifest`` as ``manifest_client``.
        model_name: Passed to ``get_manifest`` as ``manifest_engine``.
        prompt_format: Key into ``PROMPT_FORMATTERS``. Any value starting with
            "custom" selects the "custom" formatter with ``custom_prompt`` as
            its ``PROMPT_TEMPLATE``.
        custom_prompt: Prompt template, used only for custom prompt formats.
        output_file: ``Path`` of the JSON Lines file to produce.
    """
    dataset_path = str(DEV_DATASET_PATH)
    table_meta_path = str(TABLES_METADATA_PATH)
    stop_tokens = [';']
    max_tokens = 30000
    temperature = 0.1
    num_beams = -1
    manifest_client = inference_api
    manifest_engine = model_name
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False

    yield "Starting prediction..."

    tmp_file = output_file.with_name(output_file.name + ".tmp")
    try:
        data_formatter = DefaultLoader()
        if prompt_format.startswith("custom"):
            prompt_formatter_cls = PROMPT_FORMATTERS["custom"]
            # NOTE(review): this sets a class attribute shared by every user of
            # the "custom" formatter; concurrent runs with different custom
            # prompts would overwrite each other's template.
            prompt_formatter_cls.PROMPT_TEMPLATE = custom_prompt
            prompt_formatter = prompt_formatter_cls()
        else:
            prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        data = data_formatter.load_data(dataset_path)
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)

        # Build one generation request per question; questions without a
        # db_id get no table context.
        text_to_sql_inputs = []
        for input_question in data:
            db_id = input_question.get("db_id", "none")
            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []
            text_to_sql_inputs.append(TextToSQLParams(
                instruction=input_question["question"],
                database=db_id,
                tables=table_params,
            ))

        # Each result is unpacked below as (sql, _); only the first element is used.
        generated_sqls = list(generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel,
        ))
        # Evaluation pairs gold and predicted rows by position, so a short
        # result list must not be silently truncated by zip().
        if len(generated_sqls) != len(text_to_sql_inputs):
            raise ValueError(
                f"generate_sql returned {len(generated_sqls)} results "
                f"for {len(text_to_sql_inputs)} questions"
            )

        output_file.parent.mkdir(parents=True, exist_ok=True)
        with tmp_file.open('w', encoding='utf-8') as f:
            for original_data, (sql, _) in zip(data, generated_sqls):
                json.dump({**original_data, "pred": sql}, f)
                f.write('\n')
        # Publish the file only once it is complete: run_evaluation skips
        # prediction whenever output_file already exists.
        tmp_file.replace(output_file)

        for original_data, (sql, _) in zip(data, generated_sqls):
            save_prediction(inference_api, model_name, prompt_format, original_data["question"], sql)

        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        yield f"Prediction failed with error: {e}"
        yield f"Error traceback: {traceback.format_exc()}"
    finally:
        tmp_file.unlink(missing_ok=True)


def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort", custom_prompt=None):
    """Predict (unless cached) and evaluate *model_name* on the DuckDB-NSQL dev set.

    This is a generator that yields human-readable progress strings; unexpected
    errors are reported as yielded messages rather than raised.

    Requires the ``OPENROUTER_API_KEY`` and ``HF_TOKEN`` environment variables.
    Predictions are cached in ``eval/output`` under a name derived from the
    prompt format, model name and current date. An existing file is reused
    instead of re-running prediction. When ``prompt_format == "custom"``, a
    stable digest of ``custom_prompt`` is appended to the format name so that
    different custom prompts get different cache files.

    Args:
        inference_api: Inference client name, forwarded to ``run_prediction``.
        model_name: Model identifier; "/" is replaced by "_" in file names.
        prompt_format: Key into ``PROMPT_FORMATTERS``, or "custom".
        custom_prompt: Prompt template used when ``prompt_format`` is "custom".
    """
    if "OPENROUTER_API_KEY" not in os.environ:
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return
    if "HF_TOKEN" not in os.environ:
        yield "Error: HF_TOKEN not found in environment variables."
        return

    try:
        dataset_path = str(DEV_DATASET_PATH)
        table_meta_path = str(TABLES_METADATA_PATH)
        output_dir = eval_dir / "output"

        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"

        if prompt_format == "custom":
            prompt_format = prompt_format + "_" + str(_stable_prompt_id(custom_prompt))
        output_file = output_dir / (
            f"{prompt_format}_0docs_{model_name.replace('/', '_')}"
            f"_dev_{datetime.now().strftime('%y-%m-%d')}.json"
        )

        output_dir.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            for output in run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
                yield output
            # run_prediction reports failures as messages and only creates
            # output_file on success, so a missing file means it failed.
            if not output_file.exists():
                yield "Prediction did not produce an output file; skipping evaluation."
                return

        yield "Starting evaluation..."

        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        with gold_path.open("r", encoding="utf-8") as f:
            gold_records = json.load(f)
        pred_records = _read_jsonl(output_file)

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_records]
        setup_sqls = [p["setup_sql"] for p in gold_records]
        validate_sqls = [p["validation_sql"] for p in gold_records]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_records]
        pred_sqls = [p["pred"] for p in pred_records]
        gold_categories = [p.get("category", "") for p in gold_records]

        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=gold_categories,
        )

        # save_evaluation indexes metrics['exec'], so only call it when
        # metrics were actually returned.
        if metrics:
            save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics)

        yield "Evaluation completed."

        if not metrics:
            yield "No evaluation metrics returned."
            return

        yield "Overall Results:"
        overall_metrics = metrics['exec']['all']
        yield f"All (n={overall_metrics['count']}) - Execution Accuracy: {overall_metrics['exec']:.3f}"
        yield f"All (n={overall_metrics['count']}) - Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"

        for category in EVAL_CATEGORIES:
            if category in metrics['exec']:
                category_metrics = metrics['exec'][category]
                yield f"{category} (n={category_metrics['count']}) - Execution Accuracy: {category_metrics['exec']:.3f}"
            else:
                yield f"{category}: No data available"

    except Exception as e:
        yield f"An unexpected error occurred: {e}"
        yield f"Error traceback: {traceback.format_exc()}"


if __name__ == "__main__":
    inference_api = input("Enter the inference API: ")
    model_name = input("Enter the model name: ")
    prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort"
    # Pass by keyword so each value lines up with run_evaluation's signature,
    # whose first parameter is inference_api.
    for result in run_evaluation(
        inference_api=inference_api,
        model_name=model_name,
        prompt_format=prompt_format,
    ):
        print(result, flush=True)