# Spaces: Running
# (Hugging Face Spaces status banner captured with this export; kept as a
#  comment so the file remains valid Python.)
# Standard library.
import json
import os
import sys
import traceback
from datetime import datetime
from pathlib import Path

# Make the bundled duckdb-nsql evaluation harness importable *before* the
# `eval.*` imports below — they only resolve once these paths are on sys.path.
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])

# Harness imports (resolved via the sys.path additions above).
from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
from eval.schema import TextToSQLParams, Table

# Prompt formats the harness supports, e.g. for display/validation by callers.
AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())
def run_prediction(model_name, prompt_format, output_file):
    """Generate SQL predictions for the dev set and write them to *output_file*.

    Generator: yields human-readable progress/status strings as it goes, so a
    UI can stream them. Results are written as JSON Lines — one object per
    input question, the original fields plus a ``"pred"`` key with the
    generated SQL.

    Args:
        model_name: Engine name passed to the OpenRouter manifest client.
        prompt_format: Key into ``PROMPT_FORMATTERS`` selecting the prompt style.
        output_file: ``pathlib.Path`` to write JSONL predictions to.
    """
    dataset_path = str(eval_dir / "data/dev.json")
    table_meta_path = str(eval_dir / "data/tables.json")

    # Generation settings (fixed for this harness).
    stop_tokens = [';']
    max_tokens = 30000
    temperature = 0.1
    num_beams = -1
    manifest_client = "openrouter"
    manifest_engine = model_name
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False

    yield "Starting prediction..."
    try:
        # Initialize loader, prompt formatter, and the model manifest.
        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        # Load questions and per-database table metadata.
        data = data_formatter.load_data(dataset_path)
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)

        # Build one TextToSQLParams per question, attaching that question's
        # database schema (empty when db_id is "none" or unknown).
        text_to_sql_inputs = []
        for input_question in data:
            question = input_question["question"]
            db_id = input_question.get("db_id", "none")
            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []
            if not table_params:
                yield f"[red] WARNING: No tables found for {db_id} [/red]"
            text_to_sql_inputs.append(TextToSQLParams(
                instruction=question,
                database=db_id,
                tables=table_params,
            ))

        # Generate SQL for every question in one batch call.
        generated_sqls = generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],  # no retrieval docs
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel,
        )

        # Write JSONL: original record plus the predicted SQL.
        # encoding pinned to UTF-8 so output is platform-independent.
        with output_file.open('w', encoding='utf-8') as f:
            for original_data, (sql, _) in zip(data, generated_sqls):
                output = {**original_data, "pred": sql}
                json.dump(output, f)
                f.write('\n')
        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        # Broad catch is deliberate: this generator feeds a UI, so failures
        # are reported as messages rather than raised.
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
def run_evaluation(model_name, prompt_format="duckdbinstgraniteshort"):
    """Run prediction (if needed) and evaluation, yielding progress strings.

    Generator: streams status lines suitable for display. Skips the prediction
    step when a prediction file for today's date already exists, then computes
    execution/exact-match metrics against the gold dev set and yields a
    formatted report.

    Args:
        model_name: Model engine name (also used in the output filename).
        prompt_format: Prompt style key; defaults to "duckdbinstgraniteshort".
    """
    if "OPENROUTER_API_KEY" not in os.environ:
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return
    try:
        # Paths into the bundled eval harness.
        dataset_path = str(eval_dir / "data/dev.json")
        table_meta_path = str(eval_dir / "data/tables.json")
        output_dir = eval_dir / "output"
        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"
        # Filename encodes prompt format, model, and date so reruns on the
        # same day reuse the file.
        output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"
        output_dir.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            # Run prediction, forwarding its progress messages.
            yield from run_prediction(model_name, prompt_format, output_file)

        yield "Starting evaluation..."
        # Evaluation inputs: gold data, schemas, and foreign-key maps.
        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)
        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        # read_text/with-block close the files (the originals leaked handles).
        gold_sqls_dict = json.loads(gold_path.read_text(encoding="utf-8"))
        with output_file.open("r", encoding="utf-8") as pred_file:
            pred_sqls_dict = [json.loads(line) for line in pred_file]

        # Column-wise views over the gold records; predictions are aligned
        # with gold by position.
        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )
        yield "Evaluation completed."

        if metrics:
            yield "Overall Results:"
            overall_metrics = metrics['exec']['all']
            yield f"Count: {overall_metrics['count']}"
            yield f"Execution Accuracy: {overall_metrics['exec']:.3f}"
            yield f"Exact Match Accuracy: {overall_metrics['exact']:.3f}"
            yield f"Equality: {metrics['equality']['equality']:.3f}"
            yield f"Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"
            yield "\nResults by Category:"
            # Distinct name from the per-record `categories` list above.
            report_categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']
            for category in report_categories:
                if category in metrics['exec']:
                    yield f"\n{category}:"
                    category_metrics = metrics['exec'][category]
                    yield f"Count: {category_metrics['count']}"
                    yield f"Execution Accuracy: {category_metrics['exec']:.3f}"
                else:
                    yield f"\n{category}: No data available"
        else:
            yield "No evaluation metrics returned."
    except Exception as e:
        # Broad catch is deliberate: report failures as streamed messages.
        yield f"An unexpected error occurred: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
if __name__ == "__main__":
    # Interactive entry point: prompt for model and prompt format, then
    # stream the evaluation's progress lines to stdout as they arrive.
    model_name = input("Enter the model name: ")
    prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort"
    for result in run_evaluation(model_name, prompt_format):
        print(result, flush=True)  # flush so progress shows immediately