tdoehmen commited on
Commit
3a1fd8c
1 Parent(s): 4b67f9f

added inference api

Browse files
Files changed (2) hide show
  1. app.py +8 -3
  2. evaluation_logic.py +7 -4
app.py CHANGED
@@ -1,15 +1,20 @@
1
  import gradio as gr
2
  from evaluation_logic import run_evaluation, AVAILABLE_PROMPT_FORMATS
3
 
4
- def gradio_run_evaluation(model_name, prompt_format):
5
  output = []
6
- for result in run_evaluation(str(model_name).strip(), prompt_format):
7
  output.append(result)
8
  yield "\n".join(output)
9
 
10
  with gr.Blocks() as demo:
11
  gr.Markdown("# DuckDB SQL Evaluation App")
12
 
 
 
 
 
 
13
  model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
14
  prompt_format = gr.Dropdown(
15
  label="Prompt Format",
@@ -19,6 +24,6 @@ with gr.Blocks() as demo:
19
  start_btn = gr.Button("Start Evaluation")
20
  output = gr.Textbox(label="Output", lines=20)
21
 
22
- start_btn.click(fn=gradio_run_evaluation, inputs=[model_name, prompt_format], outputs=output)
23
 
24
  demo.queue().launch()
 
1
  import gradio as gr
2
  from evaluation_logic import run_evaluation, AVAILABLE_PROMPT_FORMATS
3
 
4
+ def gradio_run_evaluation(inference_api, model_name, prompt_format):
5
  output = []
6
+ for result in run_evaluation(inference_api, str(model_name).strip(), prompt_format):
7
  output.append(result)
8
  yield "\n".join(output)
9
 
10
  with gr.Blocks() as demo:
11
  gr.Markdown("# DuckDB SQL Evaluation App")
12
 
13
+ inference_api = gr.Dropdown(
14
+ label="Inference API",
15
+ choices=['openrouter', 'hf_inference_api'], #AVAILABLE_PROMPT_FORMATS,
16
+ value="openrouter"
17
+ )
18
  model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
19
  prompt_format = gr.Dropdown(
20
  label="Prompt Format",
 
24
  start_btn = gr.Button("Start Evaluation")
25
  output = gr.Textbox(label="Output", lines=20)
26
 
27
+ start_btn.click(fn=gradio_run_evaluation, inputs=[inference_api, model_name, prompt_format], outputs=output)
28
 
29
  demo.queue().launch()
evaluation_logic.py CHANGED
@@ -19,14 +19,14 @@ from eval.schema import TextToSQLParams, Table
19
 
20
  AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())
21
 
22
- def run_prediction(model_name, prompt_format, output_file):
23
  dataset_path = str(eval_dir / "data/dev.json")
24
  table_meta_path = str(eval_dir / "data/tables.json")
25
  stop_tokens = [';']
26
  max_tokens = 30000
27
  temperature = 0.1
28
  num_beams = -1
29
- manifest_client = "openrouter"
30
  manifest_engine = model_name
31
  manifest_connection = "http://localhost:5000"
32
  overwrite_manifest = True
@@ -95,10 +95,13 @@ def run_prediction(model_name, prompt_format, output_file):
95
  yield f"Prediction failed with error: {str(e)}"
96
  yield f"Error traceback: {traceback.format_exc()}"
97
 
98
- def run_evaluation(model_name, prompt_format="duckdbinstgraniteshort"):
99
  if "OPENROUTER_API_KEY" not in os.environ:
100
  yield "Error: OPENROUTER_API_KEY not found in environment variables."
101
  return
 
 
 
102
 
103
  try:
104
  # Set up the arguments
@@ -119,7 +122,7 @@ def run_evaluation(model_name, prompt_format="duckdbinstgraniteshort"):
119
  yield "Skipping prediction step and proceeding to evaluation."
120
  else:
121
  # Run prediction
122
- for output in run_prediction(model_name, prompt_format, output_file):
123
  yield output
124
 
125
  # Run evaluation
 
19
 
20
  AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())
21
 
22
+ def run_prediction(inference_api, model_name, prompt_format, output_file):
23
  dataset_path = str(eval_dir / "data/dev.json")
24
  table_meta_path = str(eval_dir / "data/tables.json")
25
  stop_tokens = [';']
26
  max_tokens = 30000
27
  temperature = 0.1
28
  num_beams = -1
29
+ manifest_client = inference_api
30
  manifest_engine = model_name
31
  manifest_connection = "http://localhost:5000"
32
  overwrite_manifest = True
 
95
  yield f"Prediction failed with error: {str(e)}"
96
  yield f"Error traceback: {traceback.format_exc()}"
97
 
98
+ def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort"):
99
  if "OPENROUTER_API_KEY" not in os.environ:
100
  yield "Error: OPENROUTER_API_KEY not found in environment variables."
101
  return
102
+ if "HF_TOKEN" not in os.environ:
103
+ yield "Error: HF_TOKEN not found in environment variables."
104
+ return
105
 
106
  try:
107
  # Set up the arguments
 
122
  yield "Skipping prediction step and proceeding to evaluation."
123
  else:
124
  # Run prediction
125
+ for output in run_prediction(inference_api, model_name, prompt_format, output_file):
126
  yield output
127
 
128
  # Run evaluation