tdoehmen committed on
Commit
4b67f9f
1 Parent(s): 6da1916

use concurrent futures instead of signal

Browse files
Files changed (2) hide show
  1. duckdb-nsql/eval/evaluate.py +29 -29
  2. evaluation_logic.py +2 -2
duckdb-nsql/eval/evaluate.py CHANGED
@@ -12,6 +12,7 @@ import click
12
  import pandas as pd
13
  from rich.console import Console
14
  from tqdm.auto import tqdm
 
15
 
16
  sys.path.append(os.path.join(os.path.dirname(__file__), "."))
17
  # from metrics.spider import evaluation as spider_evaluation # type: ignore # noqa: E402
@@ -113,15 +114,24 @@ def compute_exact_match_metric(
113
  return exact_match
114
 
115
 
 
 
 
 
 
 
 
 
 
116
  def compute_test_suite_metric(
117
- predictions: list,
118
- references: list,
119
- gold_dbs: list,
120
- setup_sqls: list,
121
- validate_sqls: list,
122
- kmaps: dict,
123
- db_dir: str,
124
- categories: list[str] = None,
125
  ) -> tuple[Any, list[int | None]]:
126
  """Compute test suite execution metric."""
127
  evaluator = test_suite_evaluation.Evaluator(
@@ -135,37 +145,27 @@ def compute_test_suite_metric(
135
  # Only used for Sparc/CoSQL
136
  turn_scores: dict[str, list] = {"exec": [], "exact": []}
137
  by_row_metrics: list[int | None] = []
 
138
  for prediction, reference, gold_db, setup_sql, validate_sql, category in tqdm(
139
- zip(predictions, references, gold_dbs, setup_sqls, validate_sqls, categories),
140
- total=len(predictions),
141
  ):
142
  turn_idx = 0
143
  # skip final utterance-query pairs
144
  if turn_idx < 0:
145
  continue
146
 
147
- # Register the timeout handler function
148
- signal.signal(signal.SIGALRM, timeout_handler)
149
- signal.alarm(TIMEOUT_SECONDS)
150
-
151
- try:
152
- ex_metrics = evaluator.evaluate_one(
153
- gold_db,
154
- reference,
155
- prediction,
156
- setup_sql,
157
- validate_sql,
158
- turn_scores,
159
- idx=turn_idx,
160
- category=category,
161
- )
162
- signal.alarm(0)
163
 
 
164
  by_row_metrics.append(int(ex_metrics["exec"]))
165
- except Exception as e:
166
- raise e
167
  by_row_metrics.append(None)
168
- pass
169
  evaluator.finalize()
170
  return evaluator.scores, by_row_metrics
171
 
 
12
  import pandas as pd
13
  from rich.console import Console
14
  from tqdm.auto import tqdm
15
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError
16
 
17
  sys.path.append(os.path.join(os.path.dirname(__file__), "."))
18
  # from metrics.spider import evaluation as spider_evaluation # type: ignore # noqa: E402
 
114
  return exact_match
115
 
116
 
117
def evaluate_with_timeout(evaluator, *args, timeout):
    """Run ``evaluator.evaluate_one(*args)`` with a wall-clock timeout.

    Args:
        evaluator: object exposing an ``evaluate_one(*args)`` method.
        *args: positional arguments forwarded to ``evaluate_one``.
        timeout: maximum number of seconds to wait for a result.

    Returns:
        Whatever ``evaluate_one`` returns, or ``None`` if it did not
        complete within ``timeout`` seconds.
    """
    # Deliberately NOT a `with` block: `with ThreadPoolExecutor(...)` calls
    # shutdown(wait=True) on exit, which blocks until evaluate_one finishes
    # and would defeat the timeout entirely.
    executor = ThreadPoolExecutor(max_workers=1)
    future = executor.submit(evaluator.evaluate_one, *args)
    try:
        result = future.result(timeout=timeout)
    except TimeoutError:
        # Best effort: cancel if not started yet. NOTE(review): if the task
        # is already running, the worker thread keeps running in the
        # background until evaluate_one returns on its own.
        future.cancel()
        result = None
    finally:
        # Return promptly without waiting for a possibly stuck worker.
        executor.shutdown(wait=False)
    return result
126
  def compute_test_suite_metric(
127
+ predictions: list,
128
+ references: list,
129
+ gold_dbs: list,
130
+ setup_sqls: list,
131
+ validate_sqls: list,
132
+ kmaps: dict,
133
+ db_dir: str,
134
+ categories: list[str] = None,
135
  ) -> tuple[Any, list[int | None]]:
136
  """Compute test suite execution metric."""
137
  evaluator = test_suite_evaluation.Evaluator(
 
145
  # Only used for Sparc/CoSQL
146
  turn_scores: dict[str, list] = {"exec": [], "exact": []}
147
  by_row_metrics: list[int | None] = []
148
+
149
  for prediction, reference, gold_db, setup_sql, validate_sql, category in tqdm(
150
+ zip(predictions, references, gold_dbs, setup_sqls, validate_sqls, categories),
151
+ total=len(predictions),
152
  ):
153
  turn_idx = 0
154
  # skip final utterance-query pairs
155
  if turn_idx < 0:
156
  continue
157
 
158
+ # Use the new function to evaluate with timeout
159
+ ex_metrics = evaluate_with_timeout(
160
+ evaluator, gold_db, reference, prediction, setup_sql, validate_sql,
161
+ turn_scores, timeout=TIMEOUT_SECONDS
162
+ )
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ if ex_metrics:
165
  by_row_metrics.append(int(ex_metrics["exec"]))
166
+ else:
 
167
  by_row_metrics.append(None)
168
+
169
  evaluator.finalize()
170
  return evaluator.scores, by_row_metrics
171
 
evaluation_logic.py CHANGED
@@ -60,8 +60,8 @@ def run_prediction(model_name, prompt_format, output_file):
60
  else:
61
  table_params = []
62
 
63
- if len(table_params) == 0:
64
- yield f"[red] WARNING: No tables found for {db_id} [/red]"
65
 
66
  text_to_sql_inputs.append(TextToSQLParams(
67
  instruction=question,
 
60
  else:
61
  table_params = []
62
 
63
+ #if len(table_params) == 0:
64
+ #yield f"[red] WARNING: No tables found for {db_id} [/red]"
65
 
66
  text_to_sql_inputs.append(TextToSQLParams(
67
  instruction=question,