Upload folder using huggingface_hub
- api.py +9 -2
- metric_utils.py +5 -1
- metrics.py +435 -19
- operators.py +34 -18
- serializers.py +20 -1
- task.py +6 -1
- templates.py +91 -10
- types.py +9 -0
- version.py +1 -1
api.py
CHANGED
@@ -7,6 +7,7 @@ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 from .artifact import fetch_artifact
 from .card import TaskCard
 from .dataset_utils import get_dataset_artifact
+from .error_utils import UnitxtError
 from .inference import (
     InferenceEngine,
     LogProbInferenceEngine,
@@ -198,8 +199,14 @@ def load_dataset(
 ).with_transform(loads_instance)
 
 
-def evaluate(
-    predictions, data) -> EvaluationResults:
+def evaluate(
+    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
+) -> EvaluationResults:
+    if dataset is None and data is None:
+        raise UnitxtError(message="Specify 'dataset' in evaluate")
+    if data is not None:
+        dataset = data  # for backward compatibility
+    return _compute(predictions=predictions, references=dataset)
 
 
 def post_process(predictions, data) -> List[Dict[str, Any]]:
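Note: a minimal usage sketch of the new `evaluate` signature. The card name and predictions below are illustrative placeholders, not part of this commit.

from unitxt.api import evaluate, load_dataset

dataset = load_dataset(card="cards.wnli", split="test")  # hypothetical card
predictions = ["entailment"] * len(dataset)  # stand-in model outputs

results = evaluate(predictions, dataset=dataset)  # new, preferred keyword
results = evaluate(predictions, data=dataset)  # old spelling, routed to `dataset` for backward compatibility
# evaluate(predictions)  # would raise UnitxtError: neither 'dataset' nor 'data' given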
metric_utils.py
CHANGED
@@ -38,7 +38,11 @@ constants = get_constants()
 
 
 def nan_mean(scores):
-    return mean(score for score in scores if score == score)
+    result = mean(score for score in scores if score == score)
+    try:
+        return float(result)
+    except:
+        return result
 
 
 class FromPredictionsAndOriginalData(StreamInitializerOperator):
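Note: the `score == score` filter relies on NaN being the only value unequal to itself, and the `float(...)` cast keeps numpy scalars out of the score dictionaries. A quick sketch of the behavior, assuming `mean` here is `statistics.mean` (which accepts a generator):

from statistics import mean

scores = [0.5, float("nan"), 1.0]
assert mean(s for s in scores if s == s) == 0.75  # the NaN entry is skipped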
metrics.py
CHANGED
@@ -7,10 +7,10 @@ import string
 import uuid
 import warnings
 from abc import ABC, abstractmethod
-from collections import Counter, defaultdict
+from collections import Counter, defaultdict, namedtuple
 from dataclasses import field
 from functools import lru_cache
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union
 
 import numpy
 import numpy as np
@@ -317,6 +317,398 @@ class Metric(Artifact):
         instance["score"]["global"].pop(score_ci)
 
 
+def new_random_generator():
+    # The np.random.default_rng expects a 32-bit int, while hash(..) can return a 64-bit integer.
+    # So use '& MAX_32BIT' to get a 32-bit seed.
+    _max_32bit = 2**32 - 1
+    return np.random.default_rng(hash(get_seed()) & _max_32bit)
+
+
+class ConfidenceIntervalMixin(Artifact):
+    n_resamples: int = 1000
+    confidence_level: float = 0.95
+    ci_score_names: List[str] = None
+
+    @abstractmethod
+    def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
+        pass
+
+    def get_statistic(self, data: List[Any], score_names: List[str]):
+        def statistic_function(indices, axis=0):
+            # indices might be a 1D or 2D array, depending on bootstrap internals
+            # For simplicity, ensure we handle them as 1D.
+            indices = np.atleast_1d(indices).astype(int)
+
+            # Gather the subset
+            sample = [data[i] for i in indices]
+
+            # Compute metrics on this sample
+            scores = self._sample_to_scores(sample)
+
+            # Return them in consistent order
+            return np.array([scores[m] for m in score_names])
+
+        return statistic_function
+
+    def bootstrap(self, data: List[Any], score_names: List[str]):
+        if self.ci_score_names is not None:
+            score_names = self.ci_score_names
+
+        intervals = bootstrap(
+            (np.arange(len(data)),),
+            statistic=self.get_statistic(data, score_names),
+            n_resamples=self.n_resamples,
+            confidence_level=self.confidence_level,
+            random_state=new_random_generator(),
+            paired=False,
+            vectorized=False,  # set to True if your statistic function is vectorized
+            method="BCa",
+        ).confidence_interval
+
+        result = {}
+        for i, metric in enumerate(score_names):
+            result[f"{metric}_ci_low"] = float(intervals.low[i])
+            result[f"{metric}_ci_high"] = float(intervals.high[i])
+
+        return result
+
+
+from typing import Generic, TypeVar, NamedTuple
+from dataclasses import dataclass
+
+IntermediateType = TypeVar("IntermediateType")
+PredictionType = TypeVar("PredictionType")
+
+
+class EvaluationInput(tuple, Generic[PredictionType]):
+    def __new__(
+        cls,
+        prediction: PredictionType,
+        references: List[PredictionType],
+        task_data: Dict[str, Any],
+    ) -> "EvaluationInput[PredictionType]":
+        return super().__new__(cls, (prediction, references, task_data))
+
+
+def is_original_key(key):
+    if (
+        key.endswith("_ci_low")
+        or key.endswith("_ci_high")
+        or key == "score"
+        or key == "num_of_instances"
+        or key == "score_name"
+    ):
+        return False
+    return True
+
+
+class MapReduceMetric(
+    StreamOperator,
+    Metric,
+    ConfidenceIntervalMixin,
+    Generic[PredictionType, IntermediateType],
+):
+    score_prefix = ""
+    reference_field: str = NonPositionalField(default="references")
+    prediction_field: str = NonPositionalField(default="prediction")
+
+    def map(
+        self,
+        prediction: PredictionType,
+        references: List[PredictionType],
+        task_data: Dict[str, Any],
+    ) -> IntermediateType:
+        raise NotImplementedError()
+
+    def reduce_one(self, intermidate: IntermediateType):
+        return self.reduce([intermidate])
+
+    @abstractmethod
+    def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
+        return {}
+
+    def disable_confidence_interval_calculation(self):
+        self.n_resamples = None
+
+    def annotate_scores(self, scores):
+        scores = {
+            **{self.score_prefix + key: val for key, val in scores.items()},
+            "score_name": self.score_prefix + self.main_score,
+            "score": scores[self.main_score],
+        }
+        for level in ["high", "low"]:
+            if f"{self.main_score}_ci_{level}" in scores:
+                scores[f"score_ci_{level}"] = scores[f"{self.main_score}_ci_{level}"]
+        return scores
+
+    def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
+        return self.reduce(sample)
+
+    def reduce_and_bootstrap(
+        self, intermediates: List[IntermediateType]
+    ) -> Dict[str, Any]:
+        scores = self.reduce(intermediates)
+        score_names = [k for k, v in scores.items() if isinstance(v, float)]
+        if self.n_resamples is None:
+            return scores
+        intervals = self.bootstrap(intermediates, score_names)
+        return {**scores, **intervals}
+
+    def _instance_to_evaluation_input(
+        self, instance: Dict[str, Any]
+    ) -> EvaluationInput[PredictionType]:
+        instance = self.verify_instance(instance)
+
+        task_data = instance.get("task_data", {})
+
+        if self.reference_field == "references":
+            references = instance["references"]
+        else:
+            references = task_data[self.reference_field]
+            if not isinstance(references, list):
+                references = [references]
+        if self.prediction_field == "prediction":
+            prediction = instance["prediction"]
+        else:
+            prediction = task_data[self.prediction_field]
+
+        self._validate_prediction(prediction)
+        self._validate_reference(references)
+
+        return EvaluationInput[PredictionType](
+            prediction=prediction, references=references, task_data=task_data
+        )
+
+    def _instances_stream_to_evaluation_inputs(
+        self, stream: Stream
+    ) -> Generator[EvaluationInput[PredictionType], None, None]:
+        for instance in stream:
+            yield self._instance_to_evaluation_input(instance)
+
+    def map_stream(
+        self,
+        evaluation_inputs_stream: Generator[
+            EvaluationInput[PredictionType], None, None
+        ],
+    ):
+        intermediates = []
+        for prediction, references, task_data in evaluation_inputs_stream:
+            intermediate = self.map(
+                prediction=prediction, references=references, task_data=task_data
+            )
+
+            intermediates.append(intermediate)
+        return intermediates
+
+    def process(self, stream: Stream, stream_name: Optional[str] = None):
+        instances_scores, global_scores = self.compute(stream, stream_name)
+        for i, (instance, instance_scores) in enumerate(zip(stream, instances_scores)):
+            previous_score = instance.get("score", {"global": {}, "instance": {}})
+
+            if i == 0:
+                for key in global_scores:
+                    if is_original_key(key) and key in previous_score["global"]:
+                        UnitxtWarning(
+                            message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded "
+                            f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. "
+                            f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
+                            f"which will yield, in this case, a score named: 'my_second_{key}')",
+                            additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
+                        )
+
+            global_scores = {**previous_score["global"], **global_scores}
+            instance_scores = {**previous_score["instance"], **instance_scores}
+
+            yield {
+                **instance,
+                "score": {"global": global_scores, "instance": instance_scores},
+            }
+
+    def compute(self, stream: Stream, stream_name: Optional[str] = None):
+        evaluation_inputs_stream = self._instances_stream_to_evaluation_inputs(stream)
+        intermediates_list = self.map_stream(evaluation_inputs_stream)
+
+        instances_scores = []
+        for intermediate in intermediates_list:
+            instance_score = self.reduce_one(intermediate)
+            instance_score = self.annotate_scores(instance_score)
+            instances_scores.append(instance_score)
+
+        global_scores = self.reduce_and_bootstrap(intermediates_list)
+        global_scores = self.annotate_scores(global_scores)
+
+        global_scores["num_of_instances"] = len(intermediates_list)
+
+        return instances_scores, global_scores
+
+
+def get_index_or_default(lst, item, default=-1):
+    try:
+        return lst.index(item)
+    except ValueError:
+        return default
+
+
+class AggregationReduction(Artifact, Generic[IntermediateType]):
+    def reduce(self, intermidates: List[IntermediateType]) -> Dict[str, Any]:
+        pass
+
+
+class DictReduction(AggregationReduction[Dict[str, float]]):
+    def reduce_list(self, lst: List[float]):
+        pass
+
+    def reduce(self, intermidates: List[Dict[str, float]]):
+        lists = {}
+        for intermidate in intermidates:
+            for key, val in intermidate.items():
+                if key not in lists:
+                    lists[key] = []
+                lists[key].append(val)
+
+        result = {}
+        for key, val_list in lists.items():
+            result[key] = self.reduce_list(val_list)
+        return result
+
+
+class MeanReduction(DictReduction):
+    def reduce_list(self, lst: List[float]):
+        return nan_mean(lst)
+
+
+class MaxReduction(DictReduction):
+    def reduce_list(self, lst: List[float]):
+        return float(nan_max(lst))
+
+
+class ReductionInstanceMetric(
+    MapReduceMetric[PredictionType, IntermediateType],
+    Generic[PredictionType, IntermediateType],
+):
+    reduction: AggregationReduction[IntermediateType]
+
+    def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
+        return self.reduction.reduce(intermediates)
+
+    def reduce_one(self, intermidate: IntermediateType):
+        return recursive_copy(intermidate)
+
+
+class AccuracyFast(ReductionInstanceMetric[str, Dict[str, float]]):
+    main_score = "accuracy"
+    reduction = MeanReduction()
+
+    def map(
+        self, prediction: str, references: List[str], task_data: Dict[str, Any]
+    ) -> Dict[str, float]:
+        return {
+            self.main_score: float(
+                str(prediction) in [str(reference) for reference in references]
+            )
+        }
+
+
+class F1Fast(MapReduceMetric[str, Tuple[int, int]]):
+    main_score = "f1"
+    averages: List[Literal["f1", "macro", "micro", "per_class"]] = [
+        "f1",
+        "micro",
+        "macro",
+        "per_class",
+    ]
+    ignore_punc: bool = True
+    ignore_case: bool = True
+    _requirements_list = ["scikit-learn", "regex"]
+
+    def prepare(self):
+        super().prepare()
+        from sklearn.metrics import f1_score
+
+        self._metric = f1_score
+        import regex
+        from functools import partial
+
+        self.remove_punc = partial(regex.compile(r"\p{P}+").sub, "")
+
+    def get_str_id(self, str):
+        if str not in self.str_to_id:
+            id = len(self.str_to_id)
+            self.str_to_id[str] = id
+            self.id_to_str[id] = str
+        return self.str_to_id[str]
+
+    def map_stream(
+        self, evaluation_inputs_stream: Generator[EvaluationInput[str], None, None]
+    ):
+        self.str_to_id = {}
+        self.id_to_str = {}
+        return super().map_stream(evaluation_inputs_stream)
+
+    def map(
+        self, prediction: str, references: List[str], task_data: Dict[str, Any]
+    ) -> Tuple[int, int]:
+        reference_index = self.get_str_id(references[0])
+        prediction_index = self.get_str_id(prediction)
+
+        return prediction_index, reference_index
+
+    def reduce(self, intermediates: List[Tuple[int, int]]) -> Dict[str, Any]:
+        y_true = []
+        y_pred = []
+        labels = set()
+        for pred_idx, ref_idx in intermediates:
+            y_pred.append(pred_idx)
+            y_true.append(ref_idx)
+            labels.add(ref_idx)
+
+        labels = list(labels)
+        result = {}
+
+        if "f1" in self.averages:
+            result["f1"] = float(
+                self._metric(
+                    y_true,
+                    y_pred,
+                    average="macro",
+                    labels=labels,
+                    zero_division=0,
+                )
+            )
+
+        if "micro" in self.averages:
+            result["f1_micro"] = float(
+                self._metric(
+                    y_true,
+                    y_pred,
+                    average="micro",
+                    labels=labels,
+                    zero_division=0,
+                )
+            )
+
+        if "macro" in self.averages:
+            result["f1_macro"] = float(
+                self._metric(
+                    y_true,
+                    y_pred,
+                    average="macro",
+                    labels=labels,
+                    zero_division=0,
+                )
+            )
+
+        if "per_class" in self.averages:
+            f1_per_class = self._metric(
+                y_true, y_pred, average=None, labels=list(labels), zero_division=0
+            )
+            for label, score in zip(labels, f1_per_class):
+                class_name = self.id_to_str[label]
+                result[f"f1_{class_name}"] = float(score)
+
+        return result
+
+
 class MetricWithConfidenceInterval(Metric):
     # The number of resamples used to estimate the confidence intervals of this metric.
     # Use None to disable confidence interval computation.
@@ -539,10 +931,10 @@ class MetricWithConfidenceInterval(Metric):
                 confidence_level=self.confidence_level,
                 random_state=random_gen,
             ).confidence_interval
-            result["score_ci_low"] = ci.low
-            result["score_ci_high"] = ci.high
-            result[f"{score_name}_ci_low"] = ci.low
-            result[f"{score_name}_ci_high"] = ci.high
+            result["score_ci_low"] = float(ci.low)
+            result["score_ci_high"] = float(ci.high)
+            result[f"{score_name}_ci_low"] = float(ci.low)
+            result[f"{score_name}_ci_high"] = float(ci.high)
             return result
 
 
@@ -1732,7 +2124,7 @@ class HuggingfaceMetric(GlobalMetric):
             **self.hf_compute_args,
         )
         if self.hf_main_score:
-            result[self.main_score] = result[self.hf_main_score]
+            result[self.main_score] = float(result[self.hf_main_score])
             del result[self.hf_main_score]
         if self.scale != 1.0:
             assert (
@@ -1752,6 +2144,8 @@ class HuggingfaceMetric(GlobalMetric):
                     result[key], float
                 ), "Scaled field '{key}' is not float: {result[key]}"
                 result[key] /= self.scale
+        if self.main_score in result:
+            result[self.main_score] = float(result[self.main_score])
         return result
 
 
@@ -1837,17 +2231,49 @@ class HuggingfaceInstanceMetric(InstanceMetric):
         return score
 
 
+class MeteorFast(ReductionInstanceMetric[str, Dict[str, float]]):
+    main_score = "meteor"
+    reduction = MeanReduction()
+    _requirements_list: List[str] = ["nltk>=3.6.6"]
+    alpha: float = 0.9
+    beta: int = 3
+    gamma: float = 0.5
+
+    def prepare(self):
+        super().prepare()
+        import nltk
+
+        nltk.download("wordnet", quiet=True)
+        nltk.download("omw-1.4", quiet=True)
+        from nltk import word_tokenize
+        from nltk.translate import meteor_score
+
+        self.word_tokenize = word_tokenize
+        self.meteor_score = meteor_score
+
+    def map(
+        self, prediction: str, references: List[str], task_data: Dict[str, Any]
+    ) -> Dict[str, float]:
+        score = self.meteor_score.meteor_score(
+            [self.word_tokenize(ref) for ref in references],
+            self.word_tokenize(prediction),
+            alpha=self.alpha,
+            beta=self.beta,
+            gamma=self.gamma,
+        )
+        return {self.main_score: score}
+
+
 class Meteor(InstanceMetric):
     main_score = "meteor"
     ci_scores = ["meteor"]
     reduction_map = {"mean": ["meteor"]}
     prediction_type = str
 
-    _requirements_list: List[str] = ["nltk"]
+    _requirements_list: List[str] = ["nltk>=3.6.6"]
     alpha: float = 0.9
     beta: int = 3
     gamma: float = 0.5
-    # unitxt uses nltk version >= 3.8
 
     def prepare(self):
         super().prepare()
@@ -1861,16 +2287,6 @@ class Meteor(InstanceMetric):
         self.word_tokenize = word_tokenize
         self.meteor_score = meteor_score
 
-    def verify(self):
-        import importlib.metadata as importlib_metadata
-
-        from datasets.config import version
-
-        nltk_version = version.parse(importlib_metadata.version("nltk"))
-        assert nltk_version >= version.Version(
-            "3.6.6"
-        ), "nltk version must be at least 3.6.6"
-
     def compute(self, references, prediction, task_data):
        score = self.meteor_score.meteor_score(
             [self.word_tokenize(ref) for ref in references],
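Note: the new `MapReduceMetric` family splits evaluation into a per-instance `map` that produces an intermediate value and a `reduce` that aggregates intermediates into scores; confidence intervals come from bootstrapping the list of intermediates. A minimal sketch of a custom metric in this style, modeled on `AccuracyFast` above (the class and score name are invented for illustration):

from typing import Any, Dict, List

class LengthMatch(ReductionInstanceMetric[str, Dict[str, float]]):
    main_score = "length_match"
    reduction = MeanReduction()

    def map(
        self, prediction: str, references: List[str], task_data: Dict[str, Any]
    ) -> Dict[str, float]:
        # One intermediate dict per instance; MeanReduction averages each key
        # across instances to form the global score.
        return {
            self.main_score: float(
                any(len(prediction) == len(ref) for ref in references)
            )
        }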
operators.py
CHANGED
@@ -55,6 +55,7 @@ from typing import (
     Generator,
     Iterable,
     List,
+    Literal,
     Optional,
     Tuple,
     Union,
@@ -1633,6 +1634,12 @@ class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
     yield from stream
 
 
+def update_scores_of_stream_instances(stream: Stream, scores: List[dict]) -> Generator:
+    for instance, score in zip(stream, scores):
+        instance["score"] = recursive_copy(score)
+        yield instance
+
+
 class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
     """Applies metric operators to a stream based on a metric field specified in each instance.
 
@@ -1647,13 +1654,6 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         from .metrics import Metric, MetricsList
 
-        def update_scores_of_stream_instances(
-            stream: Stream, scores: List[dict]
-        ) -> Generator:
-            for instance, score in zip(stream, scores):
-                instance["score"] = recursive_copy(score)
-                yield instance
-
         # to be populated only when two or more metrics
         accumulated_scores = []
 
@@ -1680,29 +1680,28 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
                 f"Operator {metric_name} must be a Metric or MetricsList"
             )
 
+        for metric in metrics_list:
+            if not self.calc_confidence_intervals:
+                metric.disable_confidence_interval_calculation()
         # Each metric operator computes its score and then sets the main score, overwriting
         # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
         # This will cause the first listed metric to run last, and the main score will be set
         # by the first listed metric (as desired).
         metrics_list = list(reversed(metrics_list))
 
-        for metric_no, metric in enumerate(metrics_list):
-            if not self.calc_confidence_intervals:
-                metric.disable_confidence_interval_calculation()
-
-            if metric_no > 0:
-                # update input stream with accumulated scores
+        for i, metric in enumerate(metrics_list):
+            if i == 0:  # first metric
+                multi_stream = MultiStream({"tmp": stream})
+            else:  # metrics with previous scores
                 reusable_generator = ReusableGenerator(
                     generator=update_scores_of_stream_instances,
                     gen_kwargs={"stream": stream, "scores": accumulated_scores},
                 )
                 multi_stream = MultiStream.from_generators({"tmp": reusable_generator})
-            else:
-                multi_stream = MultiStream.from_iterables({"tmp": stream})
+
             multi_stream = metric(multi_stream)
-
-            if metric_no < len(metrics_list) - 1:
-                # updating accumulated_scores
+
+            if i < len(metrics_list) - 1:  # last metric
                 accumulated_scores = []
                 for inst in multi_stream["tmp"]:
                     accumulated_scores.append(recursive_copy(inst["score"]))
@@ -2214,3 +2213,20 @@ class CollateInstances(StreamOperator):
             f"batch_size must be an integer equal to or greater than 1. "
             f"Got: {self.batch_size}."
         )
+
+
+class WikipediaFetcher(FieldOperator):
+    mode: Literal["summary", "text"] = "text"
+    _requirements_list = ["Wikipedia-API"]
+
+    def prepare(self):
+        super().prepare()
+        import wikipediaapi
+
+        self.wikipedia = wikipediaapi.Wikipedia("Unitxt")
+
+    def process_value(self, value: Any) -> Any:
+        title = value.split("/")[-1]
+        page = self.wikipedia.page(title)
+
+        return {"title": page.title, "body": getattr(page, self.mode)}
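Note: `WikipediaFetcher` turns a Wikipedia URL (or a bare page title, since only the last path segment is used) into a `{title, body}` dict that matches the new `Document` type. A rough usage sketch; assumes the Wikipedia-API package and network access, and calling `prepare()` by hand is only for illustration:

fetcher = WikipediaFetcher(mode="summary")
fetcher.prepare()
doc = fetcher.process_value("https://en.wikipedia.org/wiki/Alan_Turing")
# doc -> {"title": "Alan Turing", "body": "<page summary text>"}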
serializers.py
CHANGED
@@ -7,7 +7,7 @@ from .dataclass import AbstractField, Field
 from .operators import InstanceFieldOperator
 from .settings_utils import get_constants
 from .type_utils import isoftype, to_type_string
-from .types import Dialog, Image, Number, Table, Video
+from .types import Dialog, Document, Image, MultiDocument, Number, Table, Video
 
 constants = get_constants()
 
@@ -127,9 +127,28 @@ class VideoSerializer(ImageSerializer):
         return "".join(serialized_images)
 
 
+class DocumentSerializer(SingleTypeSerializer):
+    serialized_type = Document
+
+    def serialize(self, value: Document, instance: Dict[str, Any]) -> str:
+        return f"# {value['title']}\n\n{value['body']}"
+
+
+class MultiDocumentSerializer(DocumentSerializer):
+    serialized_type = MultiDocument
+
+    def serialize(self, value: MultiDocument, instance: Dict[str, Any]) -> str:
+        documents = []
+        for document in value:
+            documents.append(super().serialize(document, instance))
+        return "\n\n".join(documents)
+
+
 class MultiTypeSerializer(Serializer):
     serializers: List[SingleTypeSerializer] = Field(
         default_factory=lambda: [
+            DocumentSerializer(),
+            MultiDocumentSerializer(),
             ImageSerializer(),
             VideoSerializer(),
             TableSerializer(),
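Note: `DocumentSerializer` renders a document as a Markdown-style section, and `MultiDocumentSerializer` joins several such sections; listing both ahead of the other serializers lets `MultiTypeSerializer` dispatch on the new types first. For example (illustrative values):

doc = {"title": "Alan Turing", "body": "British mathematician."}
DocumentSerializer().serialize(doc, {})
# -> "# Alan Turing\n\nBritish mathematician."
MultiDocumentSerializer().serialize([doc, doc], {})
# -> the same section twice, separated by a blank line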
task.py
CHANGED
@@ -116,13 +116,18 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
             self.prediction_type
         )
 
+        if hasattr(self, "inputs") and self.inputs is not None:
+            self.inputs = self.input_fields
+
+        if hasattr(self, "outputs") and self.outputs is not None:
+            self.outputs = self.reference_fields
+
     def task_deprecations(self):
         if hasattr(self, "inputs") and self.inputs is not None:
             depr_message = (
                 "The 'inputs' field is deprecated. Please use 'input_fields' instead."
             )
             warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
-
         if hasattr(self, "outputs") and self.outputs is not None:
             depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
             warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
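Note: after this change the deprecated `inputs`/`outputs` fields are re-pointed at `input_fields`/`reference_fields` during verification, so both spellings refer to the same underlying data. A sketch (the task definition is illustrative; assumes unitxt artifacts run `verify()` on construction):

task = Task(
    inputs={"question": str},  # deprecated spelling, emits a DeprecationWarning
    outputs={"answer": str},   # deprecated spelling
    prediction_type=str,
    metrics=["metrics.accuracy"],
)
assert task.inputs is task.input_fields
assert task.outputs is task.reference_fields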
templates.py
CHANGED
@@ -495,7 +495,31 @@ class PairwiseComparativeRatingTemplate(InputOutputTemplate):
 
 
 class MultipleChoiceTemplate(InputFormatTemplate):
-    """Formats the input..."""
+    """Formats the input that specifies a multiple-choice question, with a list of possible answers to choose from, and identifies the correct answer.
+
+    Args:
+        target_prefix (str): Optional prefix that can be added before the target label in
+            generated prompts or outputs.
+        choices_field (str): The key under which the multiple choices are stored in the
+            input and reference dictionaries.
+        target_field (str): The key under which the correct choice is stored in the
+            reference dictionary (can be integer index or textual label).
+        choices_separator (str): A string used to join formatted choices (e.g. ", ").
+        source_choice_format (str): A Python format string used for displaying each choice
+            in the input fields (e.g. "{choice_numeral}. {choice_text}").
+        target_choice_format (str): A Python format string used for displaying each choice
+            in the target or final output (e.g. "{choice_numeral}").
+        enumerator (str): Determines how choice numerals are enumerated. Possible values
+            include "capitals", "lowercase", "numbers", or "roman".
+        shuffle_choices (bool): If True, shuffle the choices. The shuffling seed can be
+            set with `shuffle_choices_seed`.
+        shuffle_choices_seed (int, optional): If provided, the choices are shuffled with
+            this fixed integer seed for reproducibility.
+        sort_choices_by_length (bool): If True, sorts choices by their length (ascending).
+        sort_choices_alphabetically (bool): If True, sorts choices in alphabetical order.
+        reverse_choices (bool): If True, reverses the order of the choices after any
+            sorting has been applied. Defaults to False to preserve backward compatibility.
+    """
 
     target_prefix: str = ""
     choices_field: str = "choices"
@@ -504,7 +528,13 @@ class MultipleChoiceTemplate(InputFormatTemplate):
     source_choice_format: str = "{choice_numeral}. {choice_text}"
     target_choice_format: str = "{choice_numeral}"
     enumerator: str = "capitals"
+
     shuffle_choices: bool = False
+    shuffle_choices_seed: int = None
+    sort_choices_by_length: bool = False
+    sort_choices_alphabetically: bool = False
+    reverse_choices: bool = False  # False by default for backward-compat
+    place_correct_choice_position: int = None
 
     def prepare(self):
         super().prepare()
@@ -538,6 +568,31 @@ class MultipleChoiceTemplate(InputFormatTemplate):
         "XX",
     ]
 
+    def verify(self):
+        super().verify()
+        if self.shuffle_choices and (
+            self.sort_choices_by_length
+            or self.sort_choices_alphabetically
+            or self.reverse_choices
+            or self.place_correct_choice_position is not None
+        ):
+            raise UnitxtError(
+                "You cannot combine shuffle_choices with sorting or reversing flags."
+            )
+
+        if self.sort_choices_by_length and self.sort_choices_alphabetically:
+            raise UnitxtError(
+                "You cannot combine both sort_choices_by_length and sort_choices_alphabetically simultaneously."
+            )
+        if self.place_correct_choice_position is not None and (
+            self.sort_choices_by_length
+            or self.sort_choices_alphabetically
+            or self.reverse_choices
+        ):
+            raise UnitxtError(
+                "You cannot combine place_correct_choice_position with sorting or reversing flags."
+            )
+
     def inputs_to_choices(self, data: Dict[str, Any], choice_format: str) -> str:
         choices = data[self.choices_field]
         enumrated_choices = []
@@ -612,18 +667,44 @@ class MultipleChoiceTemplate(InputFormatTemplate):
     def preprocess_input_and_reference_fields(
         self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        if self.shuffle_choices:
-            target_index = self.outputs_to_target_index(reference_fields)
-            original_label_choice = reference_fields[self.choices_field][target_index]
-            choices = input_fields[self.choices_field]
-            random_generator = new_random_generator({**input_fields})
+        if (
+            not self.shuffle_choices
+            and not self.sort_choices_by_length
+            and not self.sort_choices_alphabetically
+            and not self.reverse_choices
+            and self.place_correct_choice_position is None
+        ):
+            return input_fields, reference_fields
+
+        choices = input_fields[self.choices_field]
+        target_index = self.outputs_to_target_index(reference_fields)
+        original_label_choice = reference_fields[self.choices_field][target_index]
+
+        if self.sort_choices_by_length:
+            choices.sort(key=len)
+        if self.sort_choices_alphabetically:
+            choices.sort()
+        if self.reverse_choices:
+            choices.reverse()
+        if self.shuffle_choices:
+            random_generator = new_random_generator(
+                self.shuffle_choices_seed
+                if self.shuffle_choices_seed is not None
+                else {**input_fields}
+            )
             random_generator.shuffle(choices)
-            input_fields[self.choices_field] = choices
-            reference_fields[self.choices_field] = choices
-            reference_fields[self.target_field] = choices.index(original_label_choice)
+        if self.place_correct_choice_position is not None:
+            if not 0 <= self.place_correct_choice_position < len(choices):
+                raise ValueError(
+                    f"fix_correct_choice_position={self.place_correct_choice_position} out of range (0..{len(choices) - 1})."
+                )
+            choices.remove(original_label_choice)
+            choices.insert(self.place_correct_choice_position, original_label_choice)
+
+        # Update both input_fields and reference_fields once at the end
+        input_fields[self.choices_field] = choices
+        reference_fields[self.choices_field] = choices
+        reference_fields[self.target_field] = choices.index(original_label_choice)
 
         return input_fields, reference_fields
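Note: the new ordering flags are mutually exclusive with shuffling, and `place_correct_choice_position` pins the gold answer to a fixed slot after any other reordering. A sketch (field values are illustrative; assumes `verify()` runs when the artifact is built):

template = MultipleChoiceTemplate(
    input_format="{question}\n{choices}",
    target_field="answer",
    place_correct_choice_position=0,  # gold choice is always rendered first
)

MultipleChoiceTemplate(
    input_format="{question}\n{choices}",
    target_field="answer",
    shuffle_choices=True,
    sort_choices_by_length=True,
)  # -> UnitxtError: cannot combine shuffle_choices with sorting flags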
types.py
CHANGED
@@ -26,6 +26,13 @@ class Image(TypedDict):
     format: str
 
 
+class Document(TypedDict):
+    title: str
+    body: str
+
+
+MultiDocument = NewType("MultiDocument", List[Document])
+
 Video = NewType("Video", List[Image])
 
 
@@ -46,4 +53,6 @@ register_type(Table)
 register_type(Audio)
 register_type(Image)
 register_type(Video)
+register_type(Document)
+register_type(MultiDocument)
 register_type(RagResponse)
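Note: `Document` and `MultiDocument` are plain typed containers; registering them lets serializers and type checks dispatch on them. For example:

doc: Document = {"title": "Alan Turing", "body": "British mathematician."}
docs: MultiDocument = [doc, {"title": "Ada Lovelace", "body": "English mathematician."}]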
version.py
CHANGED
@@ -1 +1 @@
-version = "1.16.1"
+version = "1.16.2"