import os
import sys
from pathlib import Path
from datetime import datetime
import json
import traceback

# Add the necessary directories to the Python path
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])

# Import necessary functions and classes
from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
from eval.schema import TextToSQLParams, Table

AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())


def run_prediction(model_name, prompt_format, output_file):
    dataset_path = str(eval_dir / "data/dev.json")
    table_meta_path = str(eval_dir / "data/tables.json")
    stop_tokens = [';']
    max_tokens = 30000
    temperature = 0.1
    num_beams = -1
    manifest_client = "openrouter"
    manifest_engine = model_name
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False

    yield "Starting prediction..."

    try:
        # Initialize necessary components
        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        # Load data
        data = data_formatter.load_data(dataset_path)
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)

        # Prepare input for generate_sql
        text_to_sql_inputs = []
        for input_question in data:
            question = input_question["question"]
            db_id = input_question.get("db_id", "none")

            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []

            if len(table_params) == 0:
                yield f"[red] WARNING: No tables found for {db_id} [/red]"

            text_to_sql_inputs.append(TextToSQLParams(
                instruction=question,
                database=db_id,
                tables=table_params,
            ))

        # Generate SQL
        generated_sqls = generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],  # Assuming no retrieved docs
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel,
        )

        # Save results
        with output_file.open('w') as f:
            for original_data, (sql, _) in zip(data, generated_sqls):
                output = {**original_data, "pred": sql}
                json.dump(output, f)
                f.write('\n')

        yield f"Prediction completed. Results saved to {output_file}"

    except Exception as e:
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"


def run_evaluation(model_name, prompt_format="duckdbinstgraniteshort"):
    if "OPENROUTER_API_KEY" not in os.environ:
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return

    try:
        # Set up the arguments
        dataset_path = str(eval_dir / "data/dev.json")
        table_meta_path = str(eval_dir / "data/tables.json")
        output_dir = eval_dir / "output"

        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"

        output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"

        # Ensure the output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            # Run prediction
            for output in run_prediction(model_name, prompt_format, output_file):
                yield output

        # Run evaluation
        yield "Starting evaluation..."

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(l) for l in output_file.open("r").readlines()]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )

        yield "Evaluation completed."

        if metrics:
            yield "Overall Results:"
            overall_metrics = metrics['exec']['all']
            yield f"Count: {overall_metrics['count']}"
            yield f"Execution Accuracy: {overall_metrics['exec']:.3f}"
            yield f"Exact Match Accuracy: {overall_metrics['exact']:.3f}"
            yield f"Equality: {metrics['equality']['equality']:.3f}"
            yield f"Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"

            yield "\nResults by Category:"
            categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']
            for category in categories:
                if category in metrics['exec']:
                    yield f"\n{category}:"
                    category_metrics = metrics['exec'][category]
                    yield f"Count: {category_metrics['count']}"
                    yield f"Execution Accuracy: {category_metrics['exec']:.3f}"
                else:
                    yield f"\n{category}: No data available"
        else:
            yield "No evaluation metrics returned."

    except Exception as e:
        yield f"An unexpected error occurred: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"


if __name__ == "__main__":
    model_name = input("Enter the model name: ")
    prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort"
    for result in run_evaluation(model_name, prompt_format):
        print(result, flush=True)
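# --- Example usage (a sketch; the script filename and model identifier shown here are
# assumptions for illustration, not part of the original code) ---
#
#   $ export OPENROUTER_API_KEY=...
#   $ python run_eval.py
#   Enter the model name: qwen/qwen-2.5-72b-instruct
#   Enter the prompt format (default is duckdbinstgraniteshort):
#
# Predictions are written as JSON lines to duckdb-nsql/eval/output/, and metric lines are
# printed as run_evaluation() yields them. For non-interactive use, the generator can also
# be consumed directly, e.g.:
#
#   for line in run_evaluation("qwen/qwen-2.5-72b-instruct"):
#       print(line, flush=True)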