lewtun HF staff committed on
Commit
55958a6
·
unverified ·
2 Parent(s): efe936d 272762d

Merge pull request #51 from huggingface/add-nli

Browse files
Files changed (2) hide show
  1. app.py +88 -25
  2. utils.py +12 -0
app.py CHANGED
@@ -16,6 +16,7 @@ from utils import (
16
  create_autotrain_project_name,
17
  format_col_mapping,
18
  get_compatible_models,
 
19
  get_dataset_card_url,
20
  get_key,
21
  get_metadata,
@@ -37,6 +38,7 @@ TASK_TO_ID = {
37
  "image_multi_class_classification": 18,
38
  "binary_classification": 1,
39
  "multi_class_classification": 2,
 
40
  "entity_extraction": 4,
41
  "extractive_question_answering": 5,
42
  "translation": 6,
@@ -51,6 +53,7 @@ TASK_TO_DEFAULT_METRICS = {
51
  "recall",
52
  "accuracy",
53
  ],
 
54
  "entity_extraction": ["precision", "recall", "f1", "accuracy"],
55
  "extractive_question_answering": ["f1", "exact_match"],
56
  "translation": ["sacrebleu"],
@@ -72,7 +75,6 @@ AUTOTRAIN_TASK_TO_LANG = {
72
 
73
 
74
  SUPPORTED_TASKS = list(TASK_TO_ID.keys())
75
- UNSUPPORTED_TASKS = []
76
 
77
  # Extracted from utils.get_supported_metrics
78
  # Hardcoded for now due to speed / caching constraints
@@ -118,8 +120,6 @@ SUPPORTED_METRICS = [
118
  "jordyvl/ece",
119
  "lvwerra/ai4code",
120
  "lvwerra/amex",
121
- "lvwerra/test",
122
- "lvwerra/test_metric",
123
  ]
124
 
125
 
@@ -180,10 +180,6 @@ if metadata is None:
180
 
181
  with st.expander("Advanced configuration"):
182
  # Select task
183
- # Hack to filter for unsupported tasks
184
- # TODO(lewtun): remove this once we have SQuAD metrics support
185
- if metadata is not None and metadata[0]["task_id"] in UNSUPPORTED_TASKS:
186
- metadata = None
187
  selected_task = st.selectbox(
188
  "Select a task",
189
  SUPPORTED_TASKS,
@@ -201,6 +197,9 @@ with st.expander("Advanced configuration"):
201
  See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
202
  """,
203
  )
 
 
 
204
 
205
  # Select splits
206
  splits_resp = http_get(
@@ -215,8 +214,8 @@ with st.expander("Advanced configuration"):
215
  if split["config"] == selected_config:
216
  split_names.append(split["split"])
217
 
218
- if metadata is not None:
219
- eval_split = metadata[0]["splits"].get("eval_split", None)
220
  else:
221
  eval_split = None
222
  selected_split = st.selectbox(
@@ -260,16 +259,62 @@ with st.expander("Advanced configuration"):
260
  text_col = st.selectbox(
261
  "This column should contain the text to be classified",
262
  col_names,
263
- index=col_names.index(get_key(metadata[0]["col_mapping"], "text")) if metadata is not None else 0,
 
 
264
  )
265
  target_col = st.selectbox(
266
  "This column should contain the labels associated with the text",
267
  col_names,
268
- index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
 
 
269
  )
270
  col_mapping[text_col] = "text"
271
  col_mapping[target_col] = "target"
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  elif selected_task == "entity_extraction":
274
  with col1:
275
  st.markdown("`tokens` column")
@@ -282,12 +327,16 @@ with st.expander("Advanced configuration"):
282
  tokens_col = st.selectbox(
283
  "This column should contain the array of tokens to be classified",
284
  col_names,
285
- index=col_names.index(get_key(metadata[0]["col_mapping"], "tokens")) if metadata is not None else 0,
 
 
286
  )
287
  tags_col = st.selectbox(
288
  "This column should contain the labels associated with each part of the text",
289
  col_names,
290
- index=col_names.index(get_key(metadata[0]["col_mapping"], "tags")) if metadata is not None else 0,
 
 
291
  )
292
  col_mapping[tokens_col] = "tokens"
293
  col_mapping[tags_col] = "tags"
@@ -304,12 +353,16 @@ with st.expander("Advanced configuration"):
304
  text_col = st.selectbox(
305
  "This column should contain the text to be translated",
306
  col_names,
307
- index=col_names.index(get_key(metadata[0]["col_mapping"], "source")) if metadata is not None else 0,
 
 
308
  )
309
  target_col = st.selectbox(
310
  "This column should contain the target translation",
311
  col_names,
312
- index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
 
 
313
  )
314
  col_mapping[text_col] = "source"
315
  col_mapping[target_col] = "target"
@@ -326,19 +379,23 @@ with st.expander("Advanced configuration"):
326
  text_col = st.selectbox(
327
  "This column should contain the text to be summarized",
328
  col_names,
329
- index=col_names.index(get_key(metadata[0]["col_mapping"], "text")) if metadata is not None else 0,
 
 
330
  )
331
  target_col = st.selectbox(
332
  "This column should contain the target summary",
333
  col_names,
334
- index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
 
 
335
  )
336
  col_mapping[text_col] = "text"
337
  col_mapping[target_col] = "target"
338
 
339
  elif selected_task == "extractive_question_answering":
340
- if metadata is not None:
341
- col_mapping = metadata[0]["col_mapping"]
342
  # Hub YAML parser converts periods to hyphens, so we remap them here
343
  col_mapping = format_col_mapping(col_mapping)
344
  with col1:
@@ -362,22 +419,24 @@ with st.expander("Advanced configuration"):
362
  context_col = st.selectbox(
363
  "This column should contain the question's context",
364
  col_names,
365
- index=col_names.index(get_key(col_mapping, "context")) if metadata is not None else 0,
366
  )
367
  question_col = st.selectbox(
368
  "This column should contain the question to be answered, given the context",
369
  col_names,
370
- index=col_names.index(get_key(col_mapping, "question")) if metadata is not None else 0,
371
  )
372
  answers_text_col = st.selectbox(
373
  "This column should contain example answers to the question, extracted from the context",
374
  col_names,
375
- index=col_names.index(get_key(col_mapping, "answers.text")) if metadata is not None else 0,
376
  )
377
  answers_start_col = st.selectbox(
378
  "This column should contain the indices in the context of the first character of each `answers.text`",
379
  col_names,
380
- index=col_names.index(get_key(col_mapping, "answers.answer_start")) if metadata is not None else 0,
 
 
381
  )
382
  col_mapping[context_col] = "context"
383
  col_mapping[question_col] = "question"
@@ -395,12 +454,16 @@ with st.expander("Advanced configuration"):
395
  image_col = st.selectbox(
396
  "This column should contain the images to be classified",
397
  col_names,
398
- index=col_names.index(get_key(metadata[0]["col_mapping"], "image")) if metadata is not None else 0,
 
 
399
  )
400
  target_col = st.selectbox(
401
  "This column should contain the labels associated with the images",
402
  col_names,
403
- index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
 
 
404
  )
405
  col_mapping[image_col] = "image"
406
  col_mapping[target_col] = "target"
 
16
  create_autotrain_project_name,
17
  format_col_mapping,
18
  get_compatible_models,
19
+ get_config_metadata,
20
  get_dataset_card_url,
21
  get_key,
22
  get_metadata,
 
38
  "image_multi_class_classification": 18,
39
  "binary_classification": 1,
40
  "multi_class_classification": 2,
41
+ "natural_language_inference": 22,
42
  "entity_extraction": 4,
43
  "extractive_question_answering": 5,
44
  "translation": 6,
 
53
  "recall",
54
  "accuracy",
55
  ],
56
+ "natural_language_inference": ["f1", "precision", "recall", "auc", "accuracy"],
57
  "entity_extraction": ["precision", "recall", "f1", "accuracy"],
58
  "extractive_question_answering": ["f1", "exact_match"],
59
  "translation": ["sacrebleu"],
 
75
 
76
 
77
  SUPPORTED_TASKS = list(TASK_TO_ID.keys())
 
78
 
79
  # Extracted from utils.get_supported_metrics
80
  # Hardcoded for now due to speed / caching constraints
 
120
  "jordyvl/ece",
121
  "lvwerra/ai4code",
122
  "lvwerra/amex",
 
 
123
  ]
124
 
125
 
 
180
 
181
  with st.expander("Advanced configuration"):
182
  # Select task
 
 
 
 
183
  selected_task = st.selectbox(
184
  "Select a task",
185
  SUPPORTED_TASKS,
 
197
  See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
198
  """,
199
  )
200
+ # Some datasets have multiple metadata (one per config), so we grab the one associated with the selected config
201
+ config_metadata = get_config_metadata(selected_config, metadata)
202
+ print(f"INFO -- Config metadata: {config_metadata}")
203
 
204
  # Select splits
205
  splits_resp = http_get(
 
214
  if split["config"] == selected_config:
215
  split_names.append(split["split"])
216
 
217
+ if config_metadata is not None:
218
+ eval_split = config_metadata["splits"].get("eval_split", None)
219
  else:
220
  eval_split = None
221
  selected_split = st.selectbox(
 
259
  text_col = st.selectbox(
260
  "This column should contain the text to be classified",
261
  col_names,
262
+ index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
263
+ if config_metadata is not None
264
+ else 0,
265
  )
266
  target_col = st.selectbox(
267
  "This column should contain the labels associated with the text",
268
  col_names,
269
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
270
+ if config_metadata is not None
271
+ else 0,
272
  )
273
  col_mapping[text_col] = "text"
274
  col_mapping[target_col] = "target"
275
 
276
+ if selected_task in ["natural_language_inference"]:
277
+ config_metadata = get_config_metadata(selected_config, metadata)
278
+ with col1:
279
+ st.markdown("`text1` column")
280
+ st.text("")
281
+ st.text("")
282
+ st.text("")
283
+ st.text("")
284
+ st.text("")
285
+ st.markdown("`text2` column")
286
+ st.text("")
287
+ st.text("")
288
+ st.text("")
289
+ st.text("")
290
+ st.text("")
291
+ st.markdown("`target` column")
292
+ with col2:
293
+ text1_col = st.selectbox(
294
+ "This column should contain the first text passage to be classified",
295
+ col_names,
296
+ index=col_names.index(get_key(config_metadata["col_mapping"], "text1"))
297
+ if config_metadata is not None
298
+ else 0,
299
+ )
300
+ text2_col = st.selectbox(
301
+ "This column should contain the second text passage to be classified",
302
+ col_names,
303
+ index=col_names.index(get_key(config_metadata["col_mapping"], "text2"))
304
+ if config_metadata is not None
305
+ else 0,
306
+ )
307
+ target_col = st.selectbox(
308
+ "This column should contain the labels associated with the text",
309
+ col_names,
310
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
311
+ if config_metadata is not None
312
+ else 0,
313
+ )
314
+ col_mapping[text1_col] = "text1"
315
+ col_mapping[text2_col] = "text2"
316
+ col_mapping[target_col] = "target"
317
+
318
  elif selected_task == "entity_extraction":
319
  with col1:
320
  st.markdown("`tokens` column")
 
327
  tokens_col = st.selectbox(
328
  "This column should contain the array of tokens to be classified",
329
  col_names,
330
+ index=col_names.index(get_key(config_metadata["col_mapping"], "tokens"))
331
+ if config_metadata is not None
332
+ else 0,
333
  )
334
  tags_col = st.selectbox(
335
  "This column should contain the labels associated with each part of the text",
336
  col_names,
337
+ index=col_names.index(get_key(config_metadata["col_mapping"], "tags"))
338
+ if config_metadata is not None
339
+ else 0,
340
  )
341
  col_mapping[tokens_col] = "tokens"
342
  col_mapping[tags_col] = "tags"
 
353
  text_col = st.selectbox(
354
  "This column should contain the text to be translated",
355
  col_names,
356
+ index=col_names.index(get_key(config_metadata["col_mapping"], "source"))
357
+ if config_metadata is not None
358
+ else 0,
359
  )
360
  target_col = st.selectbox(
361
  "This column should contain the target translation",
362
  col_names,
363
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
364
+ if config_metadata is not None
365
+ else 0,
366
  )
367
  col_mapping[text_col] = "source"
368
  col_mapping[target_col] = "target"
 
379
  text_col = st.selectbox(
380
  "This column should contain the text to be summarized",
381
  col_names,
382
+ index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
383
+ if config_metadata is not None
384
+ else 0,
385
  )
386
  target_col = st.selectbox(
387
  "This column should contain the target summary",
388
  col_names,
389
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
390
+ if config_metadata is not None
391
+ else 0,
392
  )
393
  col_mapping[text_col] = "text"
394
  col_mapping[target_col] = "target"
395
 
396
  elif selected_task == "extractive_question_answering":
397
+ if config_metadata is not None:
398
+ col_mapping = config_metadata["col_mapping"]
399
  # Hub YAML parser converts periods to hyphens, so we remap them here
400
  col_mapping = format_col_mapping(col_mapping)
401
  with col1:
 
419
  context_col = st.selectbox(
420
  "This column should contain the question's context",
421
  col_names,
422
+ index=col_names.index(get_key(col_mapping, "context")) if config_metadata is not None else 0,
423
  )
424
  question_col = st.selectbox(
425
  "This column should contain the question to be answered, given the context",
426
  col_names,
427
+ index=col_names.index(get_key(col_mapping, "question")) if config_metadata is not None else 0,
428
  )
429
  answers_text_col = st.selectbox(
430
  "This column should contain example answers to the question, extracted from the context",
431
  col_names,
432
+ index=col_names.index(get_key(col_mapping, "answers.text")) if config_metadata is not None else 0,
433
  )
434
  answers_start_col = st.selectbox(
435
  "This column should contain the indices in the context of the first character of each `answers.text`",
436
  col_names,
437
+ index=col_names.index(get_key(col_mapping, "answers.answer_start"))
438
+ if config_metadata is not None
439
+ else 0,
440
  )
441
  col_mapping[context_col] = "context"
442
  col_mapping[question_col] = "question"
 
454
  image_col = st.selectbox(
455
  "This column should contain the images to be classified",
456
  col_names,
457
+ index=col_names.index(get_key(config_metadata["col_mapping"], "image"))
458
+ if config_metadata is not None
459
+ else 0,
460
  )
461
  target_col = st.selectbox(
462
  "This column should contain the labels associated with the images",
463
  col_names,
464
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
465
+ if config_metadata is not None
466
+ else 0,
467
  )
468
  col_mapping[image_col] = "image"
469
  col_mapping[target_col] = "target"
utils.py CHANGED
@@ -12,6 +12,7 @@ from tqdm import tqdm
12
  AUTOTRAIN_TASK_TO_HUB_TASK = {
13
  "binary_classification": "text-classification",
14
  "multi_class_classification": "text-classification",
 
15
  "entity_extraction": "token-classification",
16
  "extractive_question_answering": "question-answering",
17
  "translation": "translation",
@@ -197,3 +198,14 @@ def create_autotrain_project_name(dataset_id: str) -> str:
197
  # Project names need to be unique, so we append a random string to guarantee this
198
  project_id = str(uuid.uuid4())[:8]
199
  return f"eval-project-{dataset_id_formatted}-{project_id}"
 
 
 
 
 
 
 
 
 
 
 
 
12
  AUTOTRAIN_TASK_TO_HUB_TASK = {
13
  "binary_classification": "text-classification",
14
  "multi_class_classification": "text-classification",
15
+ "natural_language_inference": "text-classification",
16
  "entity_extraction": "token-classification",
17
  "extractive_question_answering": "question-answering",
18
  "translation": "translation",
 
198
  # Project names need to be unique, so we append a random string to guarantee this
199
  project_id = str(uuid.uuid4())[:8]
200
  return f"eval-project-{dataset_id_formatted}-{project_id}"
201
+
202
+
203
def get_config_metadata(config: str, metadata: Union[List[Dict], None] = None) -> Union[Dict, None]:
    """Gets the dataset card metadata for the given config.

    Args:
        config: Name of the dataset configuration to look up.
        metadata: Dataset card metadata entries; each entry is matched on its
            `config` key. May be `None` or empty when no metadata is available.

    Returns:
        The first entry whose `config` value equals `config`, or `None` when
        there is no metadata or no entry matches.
    """
    if not metadata:
        return None
    # Entries without a `config` key are skipped instead of raising KeyError, so a
    # partially filled dataset card degrades to "no metadata" rather than crashing.
    # `next` stops at the first match instead of building the full list of matches.
    return next((m for m in metadata if "config" in m and m["config"] == config), None)