# DuckDB-SQL-Eval / evaluation_logic.py
# Last commit: "set class var" (c010ad7), author: tdoehmen
# File size: 10.3 kB
import hashlib
import json
import os
import sys
import traceback
import uuid
from datetime import datetime
from pathlib import Path

from huggingface_hub import CommitScheduler
# Locate the `duckdb-nsql` directory that sits next to this file, and the
# `eval` package inside it.
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
# These directories must be on sys.path *before* the `eval.*` imports below,
# otherwise those imports cannot be resolved. Do not reorder.
sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])
from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
from eval.schema import TextToSQLParams, Table
# Names of every prompt formatter registered in PROMPT_FORMATTERS.
AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())
# Local output folders (relative to the current working directory). Each one
# is handed to a CommitScheduler below as its `folder_path`.
prediction_folder = Path("prediction_results/")
evaluation_folder = Path("evaluation_results/")
# Generated once at import time, so every record written by this process
# lands in the same prediction_<uuid>.json / evaluation_<uuid>.json file.
file_uuid = uuid.uuid4()
# Scheduler that pushes the contents of `prediction_folder` to the
# "sql-console/duckdb-nsql-predictions" dataset repo under "data/".
# NOTE(review): per the huggingface_hub docs `every` is expressed in minutes
# (i.e. a push every 10 minutes) — confirm against the installed version.
prediction_scheduler = CommitScheduler(
repo_id="sql-console/duckdb-nsql-predictions",
repo_type="dataset",
folder_path=prediction_folder,
path_in_repo="data",
every=10,
)
# Same setup for `evaluation_folder`, targeting the
# "sql-console/duckdb-nsql-scores" dataset repo.
evaluation_scheduler = CommitScheduler(
repo_id="sql-console/duckdb-nsql-scores",
repo_type="dataset",
folder_path=evaluation_folder,
path_in_repo="data",
every=10,
)
def save_prediction(inference_api, model_name, prompt_format, question, generated_sql):
    """Append one prediction record to this process's prediction log.

    The record is written as a single JSON object followed by a newline
    (JSON Lines) to ``prediction_folder / "prediction_<file_uuid>.json"``.
    The folder is created if it does not exist yet.

    Args:
        inference_api: Identifier of the inference API used for generation.
        model_name: Name of the model that produced the SQL.
        prompt_format: Name of the prompt format used.
        question: The natural-language question that was asked.
        generated_sql: The SQL text the model produced.
    """
    prediction_file = prediction_folder / f"prediction_{file_uuid}.json"
    prediction_folder.mkdir(parents=True, exist_ok=True)

    record = {
        "inference_api": inference_api,
        "model_name": model_name,
        "prompt_format": prompt_format,
        "question": question,
        "generated_sql": generated_sql,
        "timestamp": datetime.now().isoformat(),
    }

    # Hold the scheduler lock while writing so a scheduled upload of the
    # folder does not run in the middle of an append.
    with prediction_scheduler.lock:
        with prediction_file.open("a") as f:
            json.dump(record, f)
            # The file is opened in append mode and receives many records;
            # without a separator they would run together as `{...}{...}`
            # and the file could not be parsed. save_evaluation already
            # terminates each record with a newline.
            f.write('\n')
def save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics):
    """Append one flattened evaluation row to this process's score log.

    The row holds the run identification fields plus, for each known
    category, a ``<category>_count`` and ``<category>_execution_accuracy``
    column read from ``metrics['exec']``. Categories that are absent from
    ``metrics['exec']`` are recorded as ``0`` / ``0.0``. The row is written
    as one JSON object per line to
    ``evaluation_folder / "evaluation_<file_uuid>.json"``.

    Args:
        inference_api: Identifier of the inference API used.
        model_name: Name of the evaluated model.
        prompt_format: Name of the prompt format used.
        custom_prompt: Custom prompt template (stored via ``str()``).
        metrics: Mapping with an ``'exec'`` entry keyed by category, each
            value providing ``'count'`` and ``'exec'``.
    """
    target_path = evaluation_folder / f"evaluation_{file_uuid}.json"
    evaluation_folder.mkdir(parents=True, exist_ok=True)

    row = {
        "inference_api": inference_api,
        "model_name": model_name,
        "prompt_format": prompt_format,
        "custom_prompt": str(custom_prompt),
        "timestamp": datetime.now().isoformat()
    }

    # One pair of columns per category, in a fixed order.
    per_category = metrics['exec']
    for name in ('easy', 'medium', 'hard', 'duckdb', 'ddl', 'all'):
        if name in per_category:
            stats = per_category[name]
            count, accuracy = stats['count'], stats['exec']
        else:
            count, accuracy = 0, 0.0
        row[f"{name}_count"] = count
        row[f"{name}_execution_accuracy"] = accuracy

    # Write under the scheduler lock, one JSON object per line.
    with evaluation_scheduler.lock, target_path.open("a") as sink:
        json.dump(row, sink)
        sink.write('\n')
def run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
    """Generate SQL for every dev-set question and write the results.

    This is a generator: it yields human-readable progress strings. Any
    exception raised while predicting is caught and reported through two
    yielded messages (error text and traceback) instead of propagating.

    Args:
        inference_api: Passed to ``get_manifest`` as ``manifest_client``.
        model_name: Passed to ``get_manifest`` as ``manifest_engine``.
        prompt_format: Key into ``PROMPT_FORMATTERS``; any value starting
            with ``"custom"`` selects the ``"custom"`` formatter.
        custom_prompt: Template assigned to the custom formatter class's
            ``PROMPT_TEMPLATE`` attribute when a custom format is selected.
        output_file: ``Path`` of the JSON-lines file to (over)write.

    Yields:
        str: Progress and error messages.
    """
    examples_path = str(eval_dir / "data/dev.json")
    schema_path = str(eval_dir / "data/tables.json")
    generation_options = {
        "stop_tokens": [';'],
        "overwrite_manifest": True,
        "max_tokens": 30000,
        "temperature": 0.1,
        "num_beams": -1,
        "parallel": False,
    }

    yield "Starting prediction..."

    try:
        loader = DefaultLoader()

        # Pick the prompt formatter; the custom one gets its template set
        # on the class before it is instantiated.
        if not prompt_format.startswith("custom"):
            formatter = PROMPT_FORMATTERS[prompt_format]()
        else:
            formatter_cls = PROMPT_FORMATTERS["custom"]
            formatter_cls.PROMPT_TEMPLATE = custom_prompt
            formatter = formatter_cls()

        manifest = get_manifest(
            manifest_client=inference_api,
            manifest_connection="http://localhost:5000",
            manifest_engine=model_name,
        )

        examples = loader.load_data(examples_path)
        tables_by_db = loader.load_table_metadata(schema_path)

        # Build one request per question; "none" means no database, hence
        # no table metadata.
        requests = []
        for example in examples:
            instruction = example["question"]
            db_id = example.get("db_id", "none")
            tables = [] if db_id == "none" else list(tables_by_db.get(db_id, {}).values())
            requests.append(
                TextToSQLParams(instruction=instruction, database=db_id, tables=tables)
            )

        predictions = generate_sql(
            manifest=manifest,
            text_to_sql_in=requests,
            retrieved_docs=[[] for _ in requests],
            prompt_formatter=formatter,
            **generation_options,
        )

        # Persist each example together with its predicted SQL, one JSON
        # object per line, and log it to the predictions dataset.
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with output_file.open('w') as sink:
            for example, (sql, _) in zip(examples, predictions):
                json.dump(dict(example, pred=sql), sink)
                sink.write('\n')
                save_prediction(inference_api, model_name, prompt_format, example["question"], sql)

        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort", custom_prompt=None):
    """Run prediction (when needed) and evaluation, yielding progress text.

    This is a generator. It first checks that the ``OPENROUTER_API_KEY`` and
    ``HF_TOKEN`` environment variables are set and stops with an error
    message if either is missing. It then reuses today's prediction file for
    this model/prompt combination if one exists, or calls ``run_prediction``
    to create it, computes metrics with ``compute_metrics``, stores them via
    ``save_evaluation`` and yields a textual summary. Unexpected exceptions
    are caught and reported as yielded messages.

    Args:
        inference_api: Identifier of the inference API, forwarded to
            ``run_prediction`` and ``save_evaluation``.
        model_name: Model to evaluate; ``/`` is replaced by ``_`` in the
            prediction file name.
        prompt_format: Prompt format name. The exact value ``"custom"`` is
            extended with a numeric suffix derived from ``custom_prompt``.
        custom_prompt: Prompt template used with the custom format.

    Yields:
        str: Progress, result and error messages.
    """
    for required_var in ("OPENROUTER_API_KEY", "HF_TOKEN"):
        if required_var not in os.environ:
            yield f"Error: {required_var} not found in environment variables."
            return

    try:
        dataset_path = eval_dir / "data/dev.json"
        tables_path = eval_dir / "data/tables.json"
        output_dir = eval_dir / "output"

        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"

        if prompt_format == "custom":
            # The builtin hash() is salted per process for strings, so it
            # would give a different suffix (and therefore a different cache
            # file name) after every restart. A content digest is stable.
            digest = hashlib.sha256(str(custom_prompt).encode("utf-8")).hexdigest()
            prompt_format = f"{prompt_format}_{int(digest, 16) % (10 ** 8)}"

        run_date = datetime.now().strftime('%y-%m-%d')
        output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{run_date}.json"

        # Ensure the output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            yield from run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file)
            # run_prediction reports failures as messages rather than
            # raising; without a prediction file there is nothing to score.
            if not output_file.exists():
                yield "Prediction did not produce an output file; skipping evaluation."
                return

        yield "Starting evaluation..."

        db_dir = str(eval_dir / "data/databases/")
        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        # Context managers close the files deterministically.
        with dataset_path.open("r", encoding="utf-8") as gold_file:
            gold_examples = json.load(gold_file)
        with output_file.open("r") as pred_file:
            # One JSON object per line; tolerate blank lines.
            pred_examples = [json.loads(line) for line in pred_file if line.strip()]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_examples]
        setup_sqls = [p["setup_sql"] for p in gold_examples]
        validate_sqls = [p["validation_sql"] for p in gold_examples]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_examples]
        pred_sqls = [p["pred"] for p in pred_examples]
        example_categories = [p.get("category", "") for p in gold_examples]

        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=example_categories,
        )

        # Only persist when there is something to persist: save_evaluation
        # indexes into the metrics and would raise on an empty result,
        # hiding the "No evaluation metrics returned." message below.
        if metrics:
            save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics)

        yield "Evaluation completed."

        if not metrics:
            yield "No evaluation metrics returned."
            return

        yield "Overall Results:"
        overall_metrics = metrics['exec']['all']
        yield f"All (n={overall_metrics['count']}) - Execution Accuracy: {overall_metrics['exec']:.3f}"
        yield f"All (n={overall_metrics['count']}) - Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"

        for category in ('easy', 'medium', 'hard', 'duckdb', 'ddl', 'all'):
            if category in metrics['exec']:
                category_metrics = metrics['exec'][category]
                yield f"{category} (n={category_metrics['count']}) - Execution Accuracy: {category_metrics['exec']:.3f}"
            else:
                yield f"{category}: No data available"

    except Exception as e:
        yield f"An unexpected error occurred: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
if __name__ == "__main__":
    # Command-line entry point: ask for the run parameters and print every
    # progress message as it is produced.
    #
    # run_evaluation's positional parameters are
    # (inference_api, model_name, prompt_format). The inference API has to
    # be supplied explicitly; passing only the model name and prompt format
    # would shift each value into the wrong parameter.
    inference_api = input("Enter the inference API: ")
    model_name = input("Enter the model name: ")
    prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort"
    for result in run_evaluation(inference_api, model_name, prompt_format):
        print(result, flush=True)