Spaces:
Running
Running
import hashlib
import json
import os
import sys
import traceback
import uuid
from datetime import datetime
from pathlib import Path

from huggingface_hub import CommitScheduler
# Locate the vendored duckdb-nsql checkout next to this file and put it (and its
# eval/ subfolder) on sys.path so the `eval.*` imports that follow can resolve.
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir.joinpath('duckdb-nsql')
eval_dir = duckdb_nsql_dir.joinpath('eval')
sys.path.extend(map(str, (current_dir, duckdb_nsql_dir, eval_dir)))
from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql | |
from eval.evaluate import evaluate, compute_metrics, get_to_print | |
from eval.evaluate import test_suite_evaluation, read_tables_json | |
from eval.schema import TextToSQLParams, Table | |
# Names of every registered prompt formatter (the keys of PROMPT_FORMATTERS).
AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS)

# Local staging directories (relative to the current working directory) that the
# schedulers below mirror to the Hub.
prediction_folder = Path("prediction_results/")
evaluation_folder = Path("evaluation_results/")

# Random id generated once at import time; save_prediction / save_evaluation use
# it so this process writes to its own output file names.
file_uuid = uuid.uuid4()

# Settings shared by both uploaders; only the target repo and source folder differ.
_scheduler_options = {"repo_type": "dataset", "path_in_repo": "data", "every": 10}

# Each CommitScheduler pushes its folder_path into the "data" directory of the
# given Hub dataset repo on a timer (every=10 — minutes per the huggingface_hub
# docs; NOTE(review): confirm against the installed version).
prediction_scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-predictions",
    folder_path=prediction_folder,
    **_scheduler_options,
)
evaluation_scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-scores",
    folder_path=evaluation_folder,
    **_scheduler_options,
)
def save_prediction(inference_api, model_name, prompt_format, question, generated_sql):
    """Append one prediction record to this process's prediction file.

    The record is written as a single newline-terminated JSON object (JSON
    Lines) to ``prediction_folder / "prediction_<file_uuid>.json"``. That folder
    is the one ``prediction_scheduler`` uploads.

    Args:
        inference_api: Inference backend the prediction came from.
        model_name: Model identifier used for the prediction.
        prompt_format: Prompt format label used for the prediction.
        question: Natural-language question that was asked.
        generated_sql: SQL text produced by the model.
    """
    prediction_file = prediction_folder / f"prediction_{file_uuid}.json"
    prediction_folder.mkdir(parents=True, exist_ok=True)
    record = {
        "inference_api": inference_api,
        "model_name": model_name,
        "prompt_format": prompt_format,
        "question": question,
        "generated_sql": generated_sql,
        "timestamp": datetime.now().isoformat(),
    }
    # Hold the scheduler's lock while writing, presumably so an upload never
    # captures a half-written record.
    with prediction_scheduler.lock:
        with prediction_file.open("a") as f:
            json.dump(record, f)
            # Without a separator, appended records would run together as
            # `{...}{...}`, which is neither valid JSON nor JSON Lines.
            f.write('\n')
def save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics):
    """Append one flattened evaluation record to this process's evaluation file.

    The per-category execution metrics in ``metrics['exec']`` are flattened into
    ``<category>_count`` and ``<category>_execution_accuracy`` columns. Categories
    absent from ``metrics['exec']`` are recorded as 0 / 0.0. The record is written
    as one JSON line to ``evaluation_folder / "evaluation_<file_uuid>.json"``.

    Args:
        inference_api: Inference backend that was evaluated.
        model_name: Model identifier that was evaluated.
        prompt_format: Prompt format label used.
        custom_prompt: Custom prompt template (stored via ``str()``, so ``None``
            becomes ``"None"``).
        metrics: Result of ``compute_metrics``. Each ``metrics['exec'][category]``
            must provide ``'count'`` and ``'exec'`` keys.
    """
    evaluation_folder.mkdir(parents=True, exist_ok=True)
    evaluation_file = evaluation_folder / f"evaluation_{file_uuid}.json"

    record = {
        "inference_api": inference_api,
        "model_name": model_name,
        "prompt_format": prompt_format,
        "custom_prompt": str(custom_prompt),
        "timestamp": datetime.now().isoformat(),
    }

    exec_metrics = metrics['exec']
    for category in ('easy', 'medium', 'hard', 'duckdb', 'ddl', 'all'):
        if category not in exec_metrics:
            record[f"{category}_count"] = 0
            record[f"{category}_execution_accuracy"] = 0.0
            continue
        stats = exec_metrics[category]
        record[f"{category}_count"] = stats['count']
        record[f"{category}_execution_accuracy"] = stats['exec']

    # Write under the scheduler's lock so an upload does not interleave with it.
    with evaluation_scheduler.lock, evaluation_file.open("a") as f:
        f.write(json.dumps(record) + '\n')
def run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
    """Generate SQL predictions for the dev set and stream progress messages.

    Loads questions from ``eval/data/dev.json`` and table metadata from
    ``eval/data/tables.json``. It asks the model behind ``inference_api`` /
    ``model_name`` for SQL via ``generate_sql`` and writes one JSON object per
    line to ``output_file``: the original record plus a ``"pred"`` key. Each
    prediction is also appended to the prediction dataset via ``save_prediction``.

    Args:
        inference_api: Manifest client name passed to ``get_manifest``.
        model_name: Manifest engine / model identifier.
        prompt_format: Key into ``PROMPT_FORMATTERS``. Any value starting with
            ``"custom"`` uses the ``"custom"`` formatter with ``custom_prompt``
            as its template.
        custom_prompt: Prompt template, used only for custom formats.
        output_file: ``Path`` the predictions are written to.

    Yields:
        Human-readable status strings. Errors inside the prediction step are
        reported as yielded messages rather than raised. On such a failure
        ``output_file`` is left untouched.
    """
    dataset_path = str(eval_dir / "data/dev.json")
    table_meta_path = str(eval_dir / "data/tables.json")
    stop_tokens = [';']
    max_tokens = 30000
    temperature = 0.1
    num_beams = -1
    manifest_client = inference_api
    manifest_engine = model_name
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False
    # Results go to a sibling temp file that is renamed into place only once
    # complete. run_evaluation skips prediction whenever output_file exists, so a
    # partially written file from a failed run would otherwise be reused silently.
    tmp_file = output_file.with_name(output_file.name + ".tmp")
    yield "Starting prediction..."
    try:
        # Initialize necessary components
        data_formatter = DefaultLoader()
        if prompt_format.startswith("custom"):
            prompt_formatter = PROMPT_FORMATTERS["custom"]()
            prompt_formatter.PROMPT_TEMPLATE = custom_prompt
        else:
            prompt_formatter = PROMPT_FORMATTERS[prompt_format]()
        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )
        # Load data
        data = data_formatter.load_data(dataset_path)
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)
        # Prepare one TextToSQLParams per question; questions without a db_id get no tables.
        text_to_sql_inputs = []
        for input_question in data:
            question = input_question["question"]
            db_id = input_question.get("db_id", "none")
            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []
            text_to_sql_inputs.append(TextToSQLParams(
                instruction=question,
                database=db_id,
                tables=table_params,
            ))
        # Generate SQL
        generated_sqls = list(generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel,
        ))
        # zip() below would silently drop unmatched questions; fail loudly instead.
        if len(generated_sqls) != len(text_to_sql_inputs):
            raise ValueError(
                f"generate_sql returned {len(generated_sqls)} results for "
                f"{len(text_to_sql_inputs)} questions"
            )
        # Save results
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with tmp_file.open('w', encoding='utf-8') as f:
            for original_data, (sql, _) in zip(data, generated_sqls):
                output = {**original_data, "pred": sql}
                json.dump(output, f)
                f.write('\n')
                # Save prediction to dataset
                save_prediction(inference_api, model_name, prompt_format, original_data["question"], sql)
        tmp_file.replace(output_file)
        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        # Drop any partial output so it cannot be mistaken for a finished run.
        if tmp_file.exists():
            tmp_file.unlink()
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort", custom_prompt=None):
    """Run prediction (if needed) and evaluation for one model, streaming progress.

    Requires ``OPENROUTER_API_KEY`` and ``HF_TOKEN`` in the environment.
    Predictions are cached in ``eval/output`` under a file name built from the
    prompt format, model name and current date. An existing file is reused
    instead of re-running prediction. Metrics are computed against
    ``eval/data/dev.json``. When ``compute_metrics`` returns a non-empty result,
    it is appended to the evaluation dataset via ``save_evaluation``.

    Args:
        inference_api: Manifest client name forwarded to ``run_prediction``.
        model_name: Model identifier. ``/`` is replaced by ``_`` in the file name.
        prompt_format: Prompt formatter key, or ``"custom"`` to use ``custom_prompt``.
            Custom formats are labelled ``custom_<8-digit digest of the prompt>``.
        custom_prompt: Template for the custom format. Also recorded with the scores.

    Yields:
        Human-readable status and result strings. Errors are yielded, not raised.
    """
    if "OPENROUTER_API_KEY" not in os.environ:
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return
    if "HF_TOKEN" not in os.environ:
        yield "Error: HF_TOKEN not found in environment variables."
        return
    try:
        # Set up the arguments
        dataset_path = str(eval_dir / "data/dev.json")
        table_meta_path = str(eval_dir / "data/tables.json")
        output_dir = eval_dir / "output"
        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"
        if prompt_format == "custom":
            # Built-in hash() of a str is salted per interpreter process
            # (PYTHONHASHSEED). It would produce a different suffix after every
            # restart, defeating the prediction cache and splitting one prompt
            # across several labels. A content digest is stable across runs.
            digest = hashlib.sha256(str(custom_prompt).encode("utf-8")).hexdigest()
            prompt_format = prompt_format + "_" + str(int(digest, 16) % (10 ** 8))
        output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"
        # Ensure the output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)
        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            yield from run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file)
            # run_prediction reports failures as messages instead of raising.
            # Stop here rather than failing later with a confusing FileNotFoundError.
            if not output_file.exists():
                yield "No prediction file was produced; skipping evaluation."
                return
        # Run evaluation
        yield "Starting evaluation..."
        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)
        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))
        with gold_path.open("r", encoding="utf-8") as f:
            gold_sqls_dict = json.load(f)
        with output_file.open("r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a stray empty line) so json.loads does not fail on them.
            pred_sqls_dict = [json.loads(line) for line in f if line.strip()]
        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]
        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )
        # Only persist real results: save_evaluation indexes metrics['exec'] and
        # would raise on empty/None metrics.
        if metrics:
            save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics)
        yield "Evaluation completed."
        if not metrics:
            yield "No evaluation metrics returned."
            return
        yield "Overall Results:"
        overall_metrics = metrics['exec']['all']
        yield f"All (n={overall_metrics['count']}) - Execution Accuracy: {overall_metrics['exec']:.3f}"
        yield f"All (n={overall_metrics['count']}) - Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"
        report_categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']
        for category in report_categories:
            if category in metrics['exec']:
                category_metrics = metrics['exec'][category]
                yield f"{category} (n={category_metrics['count']}) - Execution Accuracy: {category_metrics['exec']:.3f}"
            else:
                yield f"{category}: No data available"
    except Exception as e:
        yield f"An unexpected error occurred: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
if __name__ == "__main__":
    # run_evaluation's first positional parameter is inference_api. Passing
    # (model_name, prompt_format) positionally would shift every argument by one,
    # so ask for the API explicitly and pass the rest by keyword.
    # NOTE(review): "openrouter" is assumed from the OPENROUTER_API_KEY requirement — confirm.
    inference_api = input("Enter the inference API (default is openrouter): ") or "openrouter"
    model_name = input("Enter the model name: ")
    prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort"
    for result in run_evaluation(inference_api, model_name=model_name, prompt_format=prompt_format):
        print(result, flush=True)