import os
import sys
from pathlib import Path
from datetime import datetime
import json
import traceback

# Add the necessary directories to the Python path
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])

# Import necessary functions and classes
from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
from eval.schema import TextToSQLParams, Table

AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())

def run_prediction(model_name, prompt_format, output_file):
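    """Generate SQL predictions for the DuckDB-NSQL dev set and write them to output_file.

    Yields human-readable progress messages so callers can stream them (e.g. to a UI).
    """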
    dataset_path = str(eval_dir / "data/dev.json")
    table_meta_path = str(eval_dir / "data/tables.json")
    stop_tokens = [';']
    max_tokens = 30000
    temperature = 0.1
    num_beams = -1
    manifest_client = "openrouter"
    manifest_engine = model_name
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False

    yield "Starting prediction..."

    try:
        # Initialize necessary components
        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        # Load data
        data = data_formatter.load_data(dataset_path)
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)

        # Prepare input for generate_sql
        text_to_sql_inputs = []
        for input_question in data:
            question = input_question["question"]
            db_id = input_question.get("db_id", "none")
            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []

            if len(table_params) == 0:
                yield f"[red] WARNING: No tables found for {db_id} [/red]"

            text_to_sql_inputs.append(TextToSQLParams(
                instruction=question,
                database=db_id,
                tables=table_params,
            ))

        # Generate SQL
        generated_sqls = generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],  # Assuming no retrieved docs
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel
        )

        # Save results as JSON Lines: each line is the original dev example
        # augmented with the model's prediction under the "pred" key.
        with output_file.open('w') as f:
            for original_data, (sql, _) in zip(data, generated_sqls):
                output = {**original_data, "pred": sql}
                json.dump(output, f)
                f.write('\n')

        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"

def run_evaluation(model_name, prompt_format="duckdbinstgraniteshort"):
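    """Run prediction for model_name (unless a cached prediction file already exists),
    then evaluate the predictions against the DuckDB-NSQL dev set.

    Yields progress messages and, on success, the computed metrics as formatted strings.
    """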
    if "OPENROUTER_API_KEY" not in os.environ:
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return

    try:
        # Set up the arguments
        dataset_path = str(eval_dir / "data/dev.json")
        table_meta_path = str(eval_dir / "data/tables.json")
        output_dir = eval_dir / "output"

        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"

        output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"

        # Ensure the output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            # Run prediction
            for output in run_prediction(model_name, prompt_format, output_file):
                yield output
            # If prediction failed, there is nothing to evaluate.
            if not output_file.exists():
                yield "No prediction file was produced; skipping evaluation."
                return

        # Run evaluation
        yield "Starting evaluation..."

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        with gold_path.open("r", encoding="utf-8") as f:
            gold_sqls_dict = json.load(f)
        with output_file.open("r", encoding="utf-8") as f:
            pred_sqls_dict = [json.loads(line) for line in f]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )

        yield "Evaluation completed."

        if metrics:
            yield "Overall Results:"
            overall_metrics = metrics['exec']['all']
            yield f"Count: {overall_metrics['count']}"
            yield f"Execution Accuracy: {overall_metrics['exec']:.3f}"
            yield f"Exact Match Accuracy: {overall_metrics['exact']:.3f}"
            yield f"Equality: {metrics['equality']['equality']:.3f}"
            yield f"Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"

            yield "\nResults by Category:"
            categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']
            
            for category in categories:
                if category in metrics['exec']:
                    yield f"\n{category}:"
                    category_metrics = metrics['exec'][category]
                    yield f"Count: {category_metrics['count']}"
                    yield f"Execution Accuracy: {category_metrics['exec']:.3f}"
                else:
                    yield f"\n{category}: No data available"
        else:
            yield "No evaluation metrics returned."
    except Exception as e:
        yield f"An unexpected error occurred: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"

if __name__ == "__main__":
    model_name = input("Enter the model name: ")
    prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort"
    for result in run_evaluation(model_name, prompt_format):
        print(result, flush=True)
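
# Illustrative invocation (assumptions: OPENROUTER_API_KEY is exported, the
# duckdb-nsql checkout sits next to this script, and "evaluate_model.py" is a
# hypothetical name for this file; the model id is only an example):
#
#   $ export OPENROUTER_API_KEY=...
#   $ python evaluate_model.py
#   Enter the model name: meta-llama/llama-3-70b-instruct
#   Enter the prompt format (default is duckdbinstgraniteshort):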