ZeroCommand commited on
Commit
666860b
·
verified ·
1 Parent(s): 3ca571c

GSK-2836-change-ui-based-on-QA-fix-splits-bug (#119)

Browse files

- fix ui and splits not updated bug (261e3ff7b4e4b2a721516698fc03bde46ee5990c)
- clean code (07461b51a6c9c15efc7988a8ca6d2127d88d4359)

app_debug.py CHANGED
@@ -74,12 +74,12 @@ def get_demo():
74
  value=get_queue_status,
75
  every=5,
76
  )
77
- with gr.Accordion(label="Log Files", open=False):
78
- with gr.Row():
79
- gr.Files(value=get_log_files, label="Log Files", every=10)
80
  with gr.Row():
81
  gr.Textbox(
82
  value=get_logs_file, every=0.5, lines=10, visible=True, label="Current Log File"
83
  )
 
 
84
  with gr.Accordion(label="Config Files", open=False):
85
  gr.Files(value=get_config_files, label="Config Files", every=10)
 
74
  value=get_queue_status,
75
  every=5,
76
  )
77
+ with gr.Accordion(label="Log Files", open=True):
 
 
78
  with gr.Row():
79
  gr.Textbox(
80
  value=get_logs_file, every=0.5, lines=10, visible=True, label="Current Log File"
81
  )
82
+ with gr.Row():
83
+ gr.Files(value=get_log_files, label="Log Files", every=10)
84
  with gr.Accordion(label="Config Files", open=False):
85
  gr.Files(value=get_config_files, label="Config Files", every=10)
app_leaderboard.py CHANGED
@@ -88,11 +88,29 @@ def get_demo(leaderboard_tab):
88
  dataset_ids = get_dataset_ids(records)
89
 
90
  column_names = records.columns.tolist()
91
- default_columns = ["model_id", "dataset_id", "total_issues", "report_link"]
 
 
92
  default_df = records[default_columns] # extract columns selected
93
  types = get_types(default_df)
94
  display_df = get_display_df(default_df) # the styled dataframe to display
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  with gr.Row():
97
  task_select = gr.Dropdown(
98
  label="Task",
@@ -110,42 +128,35 @@ def get_demo(leaderboard_tab):
110
  interactive=True,
111
  )
112
 
113
- with gr.Row():
114
- columns_select = gr.CheckboxGroup(
115
- label="Show columns",
116
- choices=column_names,
117
- value=default_columns,
118
- interactive=True,
119
- )
120
-
121
  with gr.Row():
122
  leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)
123
 
124
- def update_leaderboard_records(model_id, dataset_id, columns, task):
125
  global update_time
126
  if datetime.datetime.now() - update_time < datetime.timedelta(minutes=10):
127
  return gr.update()
128
  update_time = datetime.datetime.now()
129
  logger.info("Updating leaderboard records")
130
  leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
131
- return filter_table(model_id, dataset_id, columns, task)
132
 
133
  leaderboard_tab.select(
134
  fn=update_leaderboard_records,
135
- inputs=[model_select, dataset_select, columns_select, task_select],
136
  outputs=[leaderboard_df])
137
 
138
  @gr.on(
139
  triggers=[
140
  model_select.change,
141
  dataset_select.change,
142
- columns_select.change,
 
143
  task_select.change,
144
  ],
145
- inputs=[model_select, dataset_select, columns_select, task_select],
146
  outputs=[leaderboard_df],
147
  )
148
- def filter_table(model_id, dataset_id, columns, task):
149
  logger.info("Filtering leaderboard records")
150
  records = leaderboard.records
151
  # filter the table based on task
@@ -156,8 +167,9 @@ def get_demo(leaderboard_tab):
156
  if dataset_id and dataset_id != "Any":
157
  df = df[(df["dataset_id"] == dataset_id)]
158
 
159
- # filter the table based on the columns
160
- df = df[columns]
 
161
  types = get_types(df)
162
  display_df = get_display_df(df)
163
  return gr.update(value=display_df, datatype=types, interactive=False)
 
88
  dataset_ids = get_dataset_ids(records)
89
 
90
  column_names = records.columns.tolist()
91
+ issue_columns = column_names[:11]
92
+ info_columns = column_names[15:]
93
+ default_columns = ["dataset_id", "total_issues", "report_link"]
94
  default_df = records[default_columns] # extract columns selected
95
  types = get_types(default_df)
96
  display_df = get_display_df(default_df) # the styled dataframe to display
97
 
98
+ with gr.Row():
99
+ with gr.Column():
100
+ issue_columns_select = gr.CheckboxGroup(
101
+ label="Issue Columns",
102
+ choices=issue_columns,
103
+ value=[],
104
+ interactive=True,
105
+ )
106
+ with gr.Column():
107
+ info_columns_select = gr.CheckboxGroup(
108
+ label="Info Columns",
109
+ choices=info_columns,
110
+ value=default_columns,
111
+ interactive=True,
112
+ )
113
+
114
  with gr.Row():
115
  task_select = gr.Dropdown(
116
  label="Task",
 
128
  interactive=True,
129
  )
130
 
 
 
 
 
 
 
 
 
131
  with gr.Row():
132
  leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)
133
 
134
+ def update_leaderboard_records(model_id, dataset_id, issue_columns, info_columns, task):
135
  global update_time
136
  if datetime.datetime.now() - update_time < datetime.timedelta(minutes=10):
137
  return gr.update()
138
  update_time = datetime.datetime.now()
139
  logger.info("Updating leaderboard records")
140
  leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
141
+ return filter_table(model_id, dataset_id, issue_columns, info_columns, task)
142
 
143
  leaderboard_tab.select(
144
  fn=update_leaderboard_records,
145
+ inputs=[model_select, dataset_select, issue_columns_select, info_columns_select, task_select],
146
  outputs=[leaderboard_df])
147
 
148
  @gr.on(
149
  triggers=[
150
  model_select.change,
151
  dataset_select.change,
152
+ issue_columns_select.change,
153
+ info_columns_select.change,
154
  task_select.change,
155
  ],
156
+ inputs=[model_select, dataset_select, issue_columns_select, info_columns_select, task_select],
157
  outputs=[leaderboard_df],
158
  )
159
+ def filter_table(model_id, dataset_id, issue_columns, info_columns, task):
160
  logger.info("Filtering leaderboard records")
161
  records = leaderboard.records
162
  # filter the table based on task
 
167
  if dataset_id and dataset_id != "Any":
168
  df = df[(df["dataset_id"] == dataset_id)]
169
 
170
+ # filter the table based on the columns
171
+ issue_columns.sort()
172
+ df = df[["model_id"] + info_columns + issue_columns]
173
  types = get_types(df)
174
  display_df = get_display_df(df)
175
  return gr.update(value=display_df, datatype=types, interactive=False)
app_text_classification.py CHANGED
@@ -6,6 +6,7 @@ from io_utils import read_scanners, write_scanners
6
  from text_classification_ui_helpers import (
7
  get_related_datasets_from_leaderboard,
8
  align_columns_and_show_prediction,
 
9
  check_dataset,
10
  show_hf_token_info,
11
  precheck_model_ds_enable_example_btn,
@@ -70,7 +71,7 @@ def get_demo():
70
  with gr.Row():
71
  example_input = gr.HTML(visible=False)
72
  with gr.Row():
73
- example_prediction = gr.Label(label="Model Prediction Sample", visible=False)
74
 
75
  with gr.Row():
76
  with gr.Accordion(
@@ -94,7 +95,7 @@ def get_demo():
94
 
95
  run_inference = gr.Checkbox(value=True, label="Run with Inference API")
96
  inference_token = gr.Textbox(
97
- placeholder="hf-xxxxxxxxxxxxxxxxxxxx",
98
  value="",
99
  label="HF Token for Inference API",
100
  visible=True,
@@ -109,7 +110,7 @@ def get_demo():
109
  )
110
 
111
  with gr.Accordion(label="Scanner Advance Config (optional)", open=False):
112
- scanners = gr.CheckboxGroup(label="Scan Settings", visible=True)
113
 
114
  @gr.on(triggers=[uid_label.change], inputs=[uid_label], outputs=[scanners])
115
  def get_scanners(uid):
@@ -146,19 +147,17 @@ def get_demo():
146
  fn=get_related_datasets_from_leaderboard,
147
  inputs=[model_id_input],
148
  outputs=[dataset_id_input],
149
- ).then(
150
- fn=check_dataset,
151
- inputs=[dataset_id_input],
152
- outputs=[dataset_config_input, dataset_split_input, loading_status]
153
  )
154
 
155
  gr.on(
156
- triggers=[dataset_id_input.change],
157
  fn=check_dataset,
158
  inputs=[dataset_id_input],
159
  outputs=[dataset_config_input, dataset_split_input, loading_status]
160
  )
161
 
 
 
162
  gr.on(
163
  triggers=[model_id_input.change, dataset_id_input.change, dataset_config_input.change],
164
  fn=empty_column_mapping,
 
6
  from text_classification_ui_helpers import (
7
  get_related_datasets_from_leaderboard,
8
  align_columns_and_show_prediction,
9
+ get_dataset_splits,
10
  check_dataset,
11
  show_hf_token_info,
12
  precheck_model_ds_enable_example_btn,
 
71
  with gr.Row():
72
  example_input = gr.HTML(visible=False)
73
  with gr.Row():
74
+ example_prediction = gr.Label(label="Model Sample Prediction", visible=False)
75
 
76
  with gr.Row():
77
  with gr.Accordion(
 
95
 
96
  run_inference = gr.Checkbox(value=True, label="Run with Inference API")
97
  inference_token = gr.Textbox(
98
+ placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
99
  value="",
100
  label="HF Token for Inference API",
101
  visible=True,
 
110
  )
111
 
112
  with gr.Accordion(label="Scanner Advance Config (optional)", open=False):
113
+ scanners = gr.CheckboxGroup(visible=True)
114
 
115
  @gr.on(triggers=[uid_label.change], inputs=[uid_label], outputs=[scanners])
116
  def get_scanners(uid):
 
147
  fn=get_related_datasets_from_leaderboard,
148
  inputs=[model_id_input],
149
  outputs=[dataset_id_input],
 
 
 
 
150
  )
151
 
152
  gr.on(
153
+ triggers=[dataset_id_input.input, dataset_id_input.select],
154
  fn=check_dataset,
155
  inputs=[dataset_id_input],
156
  outputs=[dataset_config_input, dataset_split_input, loading_status]
157
  )
158
 
159
+ dataset_config_input.change(fn=get_dataset_splits, inputs=[dataset_id_input, dataset_config_input], outputs=[dataset_split_input])
160
+
161
  gr.on(
162
  triggers=[model_id_input.change, dataset_id_input.change, dataset_config_input.change],
163
  fn=empty_column_mapping,
text_classification.py CHANGED
@@ -393,4 +393,21 @@ def check_hf_token_validity(hf_token):
393
  response = requests.get(AUTH_CHECK_URL, headers=headers)
394
  if response.status_code != 200:
395
  return False
396
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  response = requests.get(AUTH_CHECK_URL, headers=headers)
394
  if response.status_code != 200:
395
  return False
396
+ return True
397
+
398
+ def get_dataset_info_from_server(dataset_id):
399
+ url = "https://datasets-server.huggingface.co/splits?dataset=" + dataset_id
400
+ response = requests.get(url)
401
+ if response.status_code != 200:
402
+ return None
403
+ return response.json()
404
+
405
+ def get_dataset_splits(dataset_id, dataset_config):
406
+ dataset_info = get_dataset_info_from_server(dataset_id)
407
+ if dataset_info is None:
408
+ return None
409
+ try:
410
+ splits = dataset_info["splits"]
411
+ return [split["split"] for split in splits if split["config"] == dataset_config]
412
+ except Exception:
413
+ return None
text_classification_ui_helpers.py CHANGED
@@ -26,7 +26,7 @@ from wordings import (
26
  NOT_TEXT_CLASSIFICATION_MODEL_RAW,
27
  UNMATCHED_MODEL_DATASET_STYLED_ERROR,
28
  CHECK_LOG_SECTION_RAW,
29
- get_styled_input,
30
  get_dataset_fetch_error_raw,
31
  )
32
  import os
@@ -44,13 +44,20 @@ def get_related_datasets_from_leaderboard(model_id):
44
  datasets_unique = list(model_records["dataset_id"].unique())
45
 
46
  if len(datasets_unique) == 0:
47
- return gr.update(choices=[], value="")
48
 
49
- return gr.update(choices=datasets_unique, value="")
50
 
51
 
52
  logger = logging.getLogger(__file__)
53
 
 
 
 
 
 
 
 
54
 
55
  def check_dataset(dataset_id):
56
  logger.info(f"Loading {dataset_id}")
@@ -62,9 +69,7 @@ def check_dataset(dataset_id):
62
  gr.update(visible=False),
63
  ""
64
  )
65
- splits = datasets.get_dataset_split_names(
66
- dataset_id, configs[0], trust_remote_code=True
67
- )
68
  return (
69
  gr.update(choices=configs, value=configs[0], visible=True),
70
  gr.update(choices=splits, value=splits[0], visible=True),
@@ -169,21 +174,24 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels,
169
 
170
  def precheck_model_ds_enable_example_btn(
171
  model_id, dataset_id, dataset_config, dataset_split
172
- ):
173
  model_id = strip_model_id_from_url(model_id)
174
  model_task = check_model_task(model_id)
175
  preload_hf_inference_api(model_id)
 
176
  if model_task is None or model_task != "text-classification":
177
  gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
178
- return (gr.update(), gr.update(),"")
179
 
180
  if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
181
- return (gr.update(), gr.update(), "")
182
 
183
  try:
184
  ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
185
  df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
186
  ds_labels, ds_features = get_labels_and_features_from_dataset(ds[dataset_split])
 
 
187
 
188
  if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
189
  gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
@@ -193,7 +201,7 @@ def precheck_model_ds_enable_example_btn(
193
  except Exception as e:
194
  # Config or split wrong
195
  logger.warn(f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}")
196
- return (gr.update(interactive=False), gr.update(value=pd.DataFrame(), visible=False), "")
197
 
198
 
199
  def align_columns_and_show_prediction(
@@ -298,8 +306,8 @@ def align_columns_and_show_prediction(
298
  )
299
 
300
  return (
301
- gr.update(value=get_styled_input(prediction_input), visible=True),
302
- gr.update(value=prediction_response, visible=True),
303
  gr.update(visible=True, open=False),
304
  gr.update(interactive=(run_inference and inference_token != "")),
305
  "",
 
26
  NOT_TEXT_CLASSIFICATION_MODEL_RAW,
27
  UNMATCHED_MODEL_DATASET_STYLED_ERROR,
28
  CHECK_LOG_SECTION_RAW,
29
+ VALIDATED_MODEL_DATASET_STYLED,
30
  get_dataset_fetch_error_raw,
31
  )
32
  import os
 
44
  datasets_unique = list(model_records["dataset_id"].unique())
45
 
46
  if len(datasets_unique) == 0:
47
+ return gr.update(choices=[])
48
 
49
+ return gr.update(choices=datasets_unique)
50
 
51
 
52
  logger = logging.getLogger(__file__)
53
 
54
+ def get_dataset_splits(dataset_id, dataset_config):
55
+ try:
56
+ splits = datasets.get_dataset_split_names(dataset_id, dataset_config, trust_remote_code=True)
57
+ return gr.update(choices=splits, value=splits[0], visible=True)
58
+ except Exception as e:
59
+ logger.warn(f"Check your dataset {dataset_id} and config {dataset_config}: {e}")
60
+ return gr.update(visible=False)
61
 
62
  def check_dataset(dataset_id):
63
  logger.info(f"Loading {dataset_id}")
 
69
  gr.update(visible=False),
70
  ""
71
  )
72
+ splits = datasets.get_dataset_split_names(dataset_id, configs[0], trust_remote_code=True)
 
 
73
  return (
74
  gr.update(choices=configs, value=configs[0], visible=True),
75
  gr.update(choices=splits, value=splits[0], visible=True),
 
174
 
175
  def precheck_model_ds_enable_example_btn(
176
  model_id, dataset_id, dataset_config, dataset_split
177
+ ):
178
  model_id = strip_model_id_from_url(model_id)
179
  model_task = check_model_task(model_id)
180
  preload_hf_inference_api(model_id)
181
+
182
  if model_task is None or model_task != "text-classification":
183
  gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
184
+ return (gr.update(interactive=False), gr.update(visible=False),"")
185
 
186
  if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
187
+ return (gr.update(interactive=False), gr.update(visible=False), "")
188
 
189
  try:
190
  ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
191
  df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
192
  ds_labels, ds_features = get_labels_and_features_from_dataset(ds[dataset_split])
193
+ if model_id == "" or model_id is None:
194
+ return (gr.update(interactive=False), gr.update(value=df, visible=True), "")
195
 
196
  if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
197
  gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
 
201
  except Exception as e:
202
  # Config or split wrong
203
  logger.warn(f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}")
204
+ return (gr.update(interactive=False), gr.update(visible=False), "")
205
 
206
 
207
  def align_columns_and_show_prediction(
 
306
  )
307
 
308
  return (
309
+ gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
310
+ gr.update(value=prediction_response, label=prediction_input, visible=True),
311
  gr.update(visible=True, open=False),
312
  gr.update(interactive=(run_inference and inference_token != "")),
313
  "",
wordings.py CHANGED
@@ -8,7 +8,7 @@ CONFIRM_MAPPING_DETAILS_MD = """
8
  <h1 style="text-align: center;">
9
  Confirm Pre-processing Details
10
  </h1>
11
- Make sure the output variable's labels and the input variable's name are accurately mapped across both the dataset and the model.
12
  """
13
  CONFIRM_MAPPING_DETAILS_FAIL_MD = """
14
  <h1 style="text-align: center;">
@@ -38,7 +38,7 @@ PREDICTION_SAMPLE_MD = """
38
 
39
  MAPPING_STYLED_ERROR_WARNING = """
40
  <h3 style="text-align: center;color: orange; background-color: #fff0f3; border-radius: 8px; padding: 10px; ">
41
- ⚠️ We're unable to automatically map the input variable's name and output variable's labels of your dataset with the model's. <b>Please manually check the mapping below.</b>
42
  </h3>
43
  """
44
 
@@ -57,7 +57,7 @@ USE_INFERENCE_API_TIP = """
57
  <a href="https://huggingface.co/docs/api-inference/detailed_parameters#text-classification-task">
58
  Hugging Face Inference API
59
  </a>
60
- . Please input your <a href="https://huggingface.co/settings/tokens">Hugging Face token</a> to do so.
61
  """
62
 
63
  HF_TOKEN_INVALID_STYLED= """
@@ -66,10 +66,10 @@ HF_TOKEN_INVALID_STYLED= """
66
  </p>
67
  """
68
 
 
 
 
 
 
69
  def get_dataset_fetch_error_raw(error):
70
  return f"""Sorry you cannot use this dataset because {error}. Contact HF team to support this dataset."""
71
-
72
- def get_styled_input(input):
73
- return f"""<h3 style="text-align: center;color: #4ca154; background-color: #e2fbe8; border-radius: 8px; padding: 10px; ">
74
- Your model and dataset have been validated! <br /> Sample input: {input}
75
- </h3>"""
 
8
  <h1 style="text-align: center;">
9
  Confirm Pre-processing Details
10
  </h1>
11
+ Make sure the output variable's labels and the input variable's name are accurately mapped across both the dataset and the model. You can select the output variable's labels from the dropdowns below.
12
  """
13
  CONFIRM_MAPPING_DETAILS_FAIL_MD = """
14
  <h1 style="text-align: center;">
 
38
 
39
  MAPPING_STYLED_ERROR_WARNING = """
40
  <h3 style="text-align: center;color: orange; background-color: #fff0f3; border-radius: 8px; padding: 10px; ">
41
+ ⚠️ We're unable to automatically map the input variable's name and output variable's labels of your dataset with the model's. Please manually check the mapping below.
42
  </h3>
43
  """
44
 
 
57
  <a href="https://huggingface.co/docs/api-inference/detailed_parameters#text-classification-task">
58
  Hugging Face Inference API
59
  </a>
60
+ . Please input your <a href="https://huggingface.co/settings/tokens">Hugging Face token</a> to do so. You can find it <a href="https://huggingface.co/settings/tokens">here</a>.
61
  """
62
 
63
  HF_TOKEN_INVALID_STYLED= """
 
66
  </p>
67
  """
68
 
69
+ VALIDATED_MODEL_DATASET_STYLED = """
70
+ <h3 style="text-align: center;color: #4ca154; background-color: #e2fbe8; border-radius: 8px; padding: 10px; ">
71
+ Your model and dataset have been validated!
72
+ </h3>"""
73
+
74
  def get_dataset_fetch_error_raw(error):
75
  return f"""Sorry you cannot use this dataset because {error}. Contact HF team to support this dataset."""