Tristan Thrush committed on
Commit a16df4c
1 Parent(s): 20f3a68

black format

Files changed (1)
  1. app.py +27 -11
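
The diff below is purely mechanical: it is the result of running black over app.py, which rewraps the two statements that exceeded the line limit and normalizes blank lines. As a rough illustration (not part of the commit), black's documented library entry points `black.format_str` and `black.FileMode` can reproduce the reformatting of the over-long `multi_class_classification` entry; the sketch below assumes black is installed (`pip install black`):

import black

# Sketch only (not part of the commit): reproduce the reformatting that black
# applies to the over-long dict entry shown in the first hunk below.
src = (
    "TASK_TO_DEFAULT_METRICS = {\n"
    '    "multi_class_classification": ["f1_micro", "f1_macro", "f1_weighted", '
    '"precision_macro", "precision_micro", "precision_weighted", "recall_macro", '
    '"recall_micro", "recall_weighted", "accuracy"],\n'
    "}\n"
)

# With black's default 88-character line length, the long list is exploded into
# one element per line with a trailing comma, which is what the first hunk shows.
print(black.format_str(src, mode=black.FileMode()))
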
app.py CHANGED
@@ -34,7 +34,18 @@ TASK_TO_ID = {
 
 TASK_TO_DEFAULT_METRICS = {
     "binary_classification": ["f1", "precision", "recall", "auc", "accuracy"],
-    "multi_class_classification": ["f1_micro", "f1_macro", "f1_weighted", "precision_macro", "precision_micro", "precision_weighted", "recall_macro", "recall_micro", "recall_weighted", "accuracy"],
+    "multi_class_classification": [
+        "f1_micro",
+        "f1_macro",
+        "f1_weighted",
+        "precision_macro",
+        "precision_micro",
+        "precision_weighted",
+        "recall_macro",
+        "recall_micro",
+        "recall_weighted",
+        "accuracy",
+    ],
     "entity_extraction": ["precision", "recall", "f1", "accuracy"],
     "extractive_question_answering": [],
     "translation": ["sacrebleu", "gen_len"],
@@ -43,6 +54,7 @@ TASK_TO_DEFAULT_METRICS = {
 
 SUPPORTED_TASKS = list(TASK_TO_ID.keys())
 
+
 @st.cache
 def get_supported_metrics():
     metrics = list_metrics()
@@ -55,10 +67,7 @@ def get_supported_metrics():
             print("Skipping the following metric, which cannot load:", metric)
 
         argspec = inspect.getfullargspec(metric_func.compute)
-        if (
-            "references" in argspec.kwonlyargs
-            and "predictions" in argspec.kwonlyargs
-        ):
+        if "references" in argspec.kwonlyargs and "predictions" in argspec.kwonlyargs:
             # We require that "references" and "predictions" are arguments
             # to the metric function. We also require that the other arguments
             # besides "references" and "predictions" have defaults and so do not
@@ -74,6 +83,7 @@ def get_supported_metrics():
                 supported_metrics.append(metric)
     return supported_metrics
 
+
 supported_metrics = get_supported_metrics()
 
 
@@ -294,17 +304,23 @@ with st.form(key="form"):
 
     compatible_models = get_compatible_models(selected_task, selected_dataset)
     st.markdown("The following metrics will be computed")
-    html_string = " ".join([
-        "<div style=\"padding-right:5px;padding-left:5px;padding-top:5px;padding-bottom:5px;float:left\">"
-        + "<div style=\"background-color:#D3D3D3;border-radius:5px;display:inline-block;padding-right:5px;padding-left:5px;color:white\">"
-        + metric + "</div></div>" for metric in TASK_TO_DEFAULT_METRICS[selected_task]
-    ])
+    html_string = " ".join(
+        [
+            '<div style="padding-right:5px;padding-left:5px;padding-top:5px;padding-bottom:5px;float:left">'
+            + '<div style="background-color:#D3D3D3;border-radius:5px;display:inline-block;padding-right:5px;padding-left:5px;color:white">'
+            + metric
+            + "</div></div>"
+            for metric in TASK_TO_DEFAULT_METRICS[selected_task]
+        ]
+    )
     st.markdown(html_string, unsafe_allow_html=True)
     selected_metrics = st.multiselect(
         "(Optional) Select additional metrics",
         list(set(supported_metrics) - set(TASK_TO_DEFAULT_METRICS[selected_task])),
     )
-    st.info("Note: user-selected metrics will be run with their default arguments from [here](https://github.com/huggingface/datasets/tree/master/metrics)")
+    st.info(
+        "Note: user-selected metrics will be run with their default arguments from [here](https://github.com/huggingface/datasets/tree/master/metrics)"
+    )
 
     selected_models = st.multiselect("Select the models you wish to evaluate", compatible_models)
     print("Selected models:", selected_models)
 