Upload folder using huggingface_hub
- README.md +1 -1
- api.py +10 -10
- benchmark.py +2 -2
- blocks.py +1 -1
- card.py +11 -10
- dataset_utils.py +2 -2
- deprecation_utils.py +14 -9
- dialog_operators.py +22 -12
- error_utils.py +12 -8
- image_operators.py +7 -5
- inference.py +86 -48
- llm_as_judge.py +23 -23
- llm_as_judge_constants.py +5 -5
- llm_as_judge_utils.py +3 -1
- loaders.py +72 -76
- metric_utils.py +6 -9
- metrics.py +66 -43
- operator.py +34 -17
- operators.py +144 -124
- processors.py +13 -0
- schema.py +3 -0
- settings_utils.py +2 -1
- span_lableing_operators.py +9 -10
- splitters.py +31 -34
- standard.py +318 -117
- struct_data_operators.py +1 -1
- task.py +24 -17
- templates.py +33 -31
- type_utils.py +3 -0
- version.py +1 -1
README.md
CHANGED
@@ -30,7 +30,7 @@ In the dynamic landscape of generative NLP, traditional text processing pipeline
 ![license](https://img.shields.io/github/license/ibm/unitxt)
 ![python](https://img.shields.io/badge/python-3.8%20|%203.9-blue)
 ![tests](https://img.shields.io/github/actions/workflow/status/ibm/unitxt/library_tests.yml?branch=main&label=tests)
-[![
+[![Coverage Status](https://coveralls.io/repos/github/IBM/unitxt/badge.svg)](https://coveralls.io/github/IBM/unitxt)
 ![Read the Docs](https://img.shields.io/readthedocs/unitxt)
 [![downloads](https://static.pepy.tech/personalized-badge/unitxt?period=total&units=international_system&left_color=grey&right_color=green&left_text=downloads)](https://pepy.tech/project/unitxt)
 
api.py
CHANGED
@@ -18,7 +18,7 @@ from .metric_utils import EvaluationResults, _compute, _inference_post_process
 from .operator import SourceOperator
 from .schema import UNITXT_DATASET_SCHEMA, loads_instance
 from .settings_utils import get_constants, get_settings
-from .standard import StandardRecipe
+from .standard import DatasetRecipe
 from .task import Task
 
 logger = get_logger()
@@ -35,7 +35,7 @@ def load(source: Union[SourceOperator, str]):
     return source().to_dataset()
 
 
-def _get_recipe_from_query(dataset_query: str) -> StandardRecipe:
+def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
@@ -44,14 +44,14 @@ def _get_recipe_from_query(dataset_query: str) -> StandardRecipe:
     return dataset_stream
 
 
-def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> StandardRecipe:
-    recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
+def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
+    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
     for param in dataset_params.keys():
         assert param in recipe_attributes, (
-            f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
+            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
             f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
         )
-    return StandardRecipe(**dataset_params)
+    return DatasetRecipe(**dataset_params)
 
 
 def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
@@ -76,8 +76,8 @@ def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None)
     )
 
 
-def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> StandardRecipe:
-    if isinstance(dataset_query, StandardRecipe):
+def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
+    if isinstance(dataset_query, DatasetRecipe):
         return dataset_query
 
     _verify_dataset_args(dataset_query, kwargs)
@@ -230,7 +230,7 @@ def infer(
     return_data: bool = False,
     return_log_probs: bool = False,
     return_meta_data: bool = False,
-    previous_messages: Optional[
+    previous_messages: Optional[List[Dict[str, str]]] = None,
     **kwargs,
 ):
     dataset = produce(instance_or_instances, dataset_query, **kwargs)
@@ -283,7 +283,7 @@ def select(
     engine: OptionSelectingByLogProbsInferenceEngine,
     dataset_query: Optional[str] = None,
     return_data: bool = False,
-    previous_messages: Optional[
+    previous_messages: Optional[List[Dict[str, str]]] = None,
     **kwargs,
 ):
     dataset = produce(instance_or_instances, dataset_query, **kwargs)
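A minimal usage sketch of the renamed entry point, following the `load`/`load_recipe` code in this diff (the catalog query string below is illustrative):

from unitxt.api import load_recipe

# load_recipe returns a DatasetRecipe: it passes an existing DatasetRecipe
# through unchanged, or parses a key=value catalog query via fetch_artifact.
recipe = load_recipe("card=cards.wnli,template=templates.classification.multi_class.relation.default")
dataset = recipe().to_dataset()  # same call chain as load() in this module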
benchmark.py
CHANGED
@@ -5,7 +5,7 @@ from .dataclass import NonPositionalField
 from .formats import Format
 from .fusion import FixedFusion, WeightedFusion
 from .operator import SourceOperator
-from .standard import StandardRecipe
+from .standard import DatasetRecipe
 from .stream import MultiStream
 from .system_prompts import SystemPrompt
 
@@ -22,7 +22,7 @@ class BaseBenchmark(SourceOperator):
 
 
 class Benchmark(BaseBenchmark):
-    subsets: Dict[str, Union[StandardRecipe, BaseBenchmark]]
+    subsets: Dict[str, Union[DatasetRecipe, BaseBenchmark]]
 
     max_total_samples: int = None
     max_samples_per_subset: int = None
blocks.py
CHANGED
@@ -18,7 +18,7 @@ from .operators import (
 )
 from .processors import ToString, ToStringStripped
 from .recipe import SequentialRecipe
-from .splitters import
+from .splitters import AssignDemosToInstance, RandomSampler, SliceSplit, SplitRandomMix
 from .stream import MultiStream
 from .struct_data_operators import (
     ConstructTableFromRowsCols,
card.py
CHANGED
@@ -12,16 +12,17 @@ from .templates import Template, TemplatesDict, TemplatesList
 class TaskCard(Artifact):
     """TaskCard delineates the phases in transforming the source dataset into model input, and specifies the metrics for evaluation of model output.
 
-
-    loader:
-
-    preprocess_steps:
-
-    task:
-
-    templates:
-
-    default_template:
+    Args:
+        loader:
+            specifies the source address and the loading operator that can access that source and transform it into a unitxt multistream.
+        preprocess_steps:
+            list of unitxt operators to process the data source into model input.
+        task:
+            specifies the fields (of the already (pre)processed instance) making the inputs, the fields making the outputs, and the metrics to be used for evaluating the model output.
+        templates:
+            format strings to be applied on the input fields (specified by the task) and the output fields. The template also carries the instructions and the list of postprocessing steps, to be applied to the model output.
+        default_template:
+            a default template for tasks with very specific task dataset specific template
     """
 
     loader: Loader
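A sketch of assembling a TaskCard from these pieces; the GLUE/MRPC source and field names are illustrative, and the exact Task parameters may differ between unitxt versions:

from unitxt.card import TaskCard
from unitxt.loaders import LoadHF
from unitxt.task import Task
from unitxt.templates import InputOutputTemplate, TemplatesList

card = TaskCard(
    loader=LoadHF(path="glue", name="mrpc"),  # source address + loading operator
    preprocess_steps=[],  # unitxt operators; left empty for brevity
    task=Task(
        input_fields={"sentence1": str, "sentence2": str},
        reference_fields={"label": int},
        prediction_type=str,
        metrics=["metrics.accuracy"],
    ),
    templates=TemplatesList(
        [InputOutputTemplate(input_format="{sentence1} {sentence2}", output_format="{label}")]
    ),
)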
dataset_utils.py
CHANGED
@@ -5,7 +5,7 @@ from .logging_utils import get_logger
 from .parsing_utils import parse_key_equals_value_string_to_dict
 from .register import _reset_env_local_catalogs, register_all_artifacts
 from .settings_utils import get_settings
-from .standard import StandardRecipe
+from .standard import DatasetRecipe
 
 logger = get_logger()
 settings = get_settings()
@@ -24,7 +24,7 @@ def parse(query: str):
 
 
 def get_dataset_artifact(dataset):
-    if isinstance(dataset, StandardRecipe):
+    if isinstance(dataset, DatasetRecipe):
         return dataset
     assert isinstance(
         dataset, str
deprecation_utils.py
CHANGED
@@ -18,19 +18,24 @@ def compare_versions(version1, version2):
     """Compare two semantic versioning strings and determine their relationship.
 
     Parameters:
-
-
+        version1 (str):
+            The first version string to compare.
+        version2 (str):
+            The second version string to compare.
 
     Returns:
-
+        int: -1 if version1 < version2, 1 if version1 > version2, 0 if equal.
 
     Example:
-
-
-
-
-
-
+        .. code-block:: text
+
+            >>> compare_versions("1.2.0", "1.2.3")
+            -1
+            >>> compare_versions("1.3.0", "1.2.8")
+            1
+            >>> compare_versions("1.0.0", "1.0.0")
+            0
+
     """
     parts1 = [int(part) for part in version1.split(".")]
     parts2 = [int(part) for part in version2.split(".")]
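The parsed integer lists compare the way semantic versions should, since Python orders lists of ints lexicographically. A quick standalone illustration of the arithmetic behind the doctest above:

# Python compares integer lists element by element, matching semver order.
parts1 = [int(part) for part in "1.2.0".split(".")]
parts2 = [int(part) for part in "1.2.3".split(".")]
print((parts1 > parts2) - (parts1 < parts2))  # -1, as in the docstring example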
dialog_operators.py
CHANGED
@@ -27,12 +27,17 @@ class SerializeDialog(InstanceFieldOperator):
     of system responses and can operate on a per-turn basis or aggregate the entire
     dialog.
 
-
-    field (str):
-
-
-
-
+    Args:
+        field (str):
+            The field in the input data that contains the dialog.
+        to_field (Optional[str]):
+            The field in the output data where the serialized dialog will be stored.
+        last_user_turn_to_field (Optional[str]):
+            Field to store the last user turn.
+        last_system_turn_to_field (Optional[str]):
+            Field to store the last system turn.
+        context_field (Optional[str]):
+            Field that contains additional context to be prepended to the dialog.
     """
 
     format: SystemFormat = None
@@ -100,12 +105,17 @@ class SerializeOpenAiFormatDialog(SerializeDialog):
     of system responses and can operate on a per-turn basis or aggregate the entire
     dialog.
 
-
-    field (str):
-
-
-
-
+    Args:
+        field (str):
+            The field in the input data that contains the dialog.
+        to_field (Optional[str]):
+            The field in the output data where the serialized dialog will be stored.
+        last_user_turn_to_field (Optional[str]):
+            Field to store the last user turn.
+        last_system_turn_to_field (Optional[str]):
+            Field to store the last system turn.
+        context_field (Optional[str]):
+            Field that contains additional context to be prepended to the dialog.
     """
 
     is_last_turn_user_only: bool = True
error_utils.py
CHANGED
@@ -27,10 +27,12 @@ def additional_info(path: str) -> str:
 class UnitxtError(Exception):
     """Exception raised for Unitxt errors.
 
-
-    message :
-
-
+    Args:
+        message (str):
+            explanation of the error
+        additional_info_id (Optional[str]):
+            relative path to additional documentation on web
+            If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
 
     """
 
@@ -43,10 +45,12 @@ class UnitxtError(Exception):
 class UnitxtWarning:
     """Object to format warning message to log.
 
-
-    message
-
-
+    Args:
+        message (str):
+            explanation of the warning
+        additional_info_id (Optional[str]):
+            relative path to additional documentation on web
+            If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
     """
 
     def __init__(self, message: str, additional_info_id: Optional[str] = None):
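A usage sketch of the documented signatures; the messages below are hypothetical, and real `additional_info_id` values come from the DOCUMENATION_* constants mentioned above:

from unitxt.error_utils import UnitxtError, UnitxtWarning

# UnitxtWarning formats and logs on construction; UnitxtError is raised.
UnitxtWarning("field 'x' is deprecated")
raise UnitxtError("metric received an empty references list")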
image_operators.py
CHANGED
@@ -216,13 +216,15 @@ class GridLines(ImageAugmentor):
 class PixelNoise(ImageAugmentor):
     """A class that overlays a mask of randomly colored nxn squares across an image based on a specified noise rate.
 
-
-    square_size (int):
-
-    noise_rate (float):
+    Args:
+        square_size (int):
+            Size of each square in pixels.
+        noise_rate (float):
+            Proportion of the image that should be affected by noise (0 to 1).
 
     Methods:
-        process_image(image):
+        process_image(image):
+            Adds the random square mask to the provided image and returns the modified image.
     """
 
     square_size: int = 1
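A rough numpy sketch of the documented behavior, not unitxt's exact implementation: walk the image in square_size steps and overwrite a noise_rate fraction of the squares with random colors.

import numpy as np

def pixel_noise(image: np.ndarray, square_size: int = 1, noise_rate: float = 0.3) -> np.ndarray:
    # image is assumed HxWxC uint8; each grid square is independently replaced
    # with probability noise_rate, mirroring the Args described above.
    h, w, c = image.shape
    out = image.copy()
    for y in range(0, h, square_size):
        for x in range(0, w, square_size):
            if np.random.random() < noise_rate:
                out[y : y + square_size, x : x + square_size] = np.random.randint(
                    0, 256, size=(1, 1, c)
                )
    return out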
inference.py
CHANGED
@@ -9,6 +9,7 @@ import sys
 import time
 import uuid
 from collections import Counter
+from multiprocessing.pool import ThreadPool
 from typing import (
     Any,
     Dict,
@@ -63,8 +64,8 @@ class StandardAPIParamsMixin(Artifact):
     n: Optional[int] = None
     parallel_tool_calls: Optional[bool] = None
     service_tier: Optional[Literal["auto", "default"]] = None
-    credentials: Optional[
-    extra_headers: Optional[
+    credentials: Optional[Dict[str, str]] = {}
+    extra_headers: Optional[Dict[str, str]] = None
 
 
 def get_model_and_label_id(model_name, label):
@@ -1171,8 +1172,8 @@ class OptionSelectingByLogProbsInferenceEngine:
             for option in instance["task_data"]["options"]
         ]
 
-        dataset_with_options_logprobs:
-
+        dataset_with_options_logprobs: List[
+            List[Dict[str, Union[float, str]]]
         ] = self.get_options_log_probs(dataset_with_options)
 
         dataset_iterator = iter(dataset_with_options_logprobs)
@@ -1469,6 +1470,13 @@ class OpenAiInferenceEngineParams(Artifact):
     service_tier: Optional[Literal["auto", "default"]] = None
 
 
+def run_with_imap(func):
+    def inner(self, args):
+        return func(self, *args)
+
+    return inner
+
+
 class OpenAiInferenceEngine(
     InferenceEngine,
     LogProbInferenceEngine,
@@ -1485,6 +1493,7 @@ class OpenAiInferenceEngine(
     base_url: Optional[str] = None
     default_headers: Dict[str, str] = {}
     credentials: CredentialsOpenAi = {}
+    num_parallel_requests: int = 20
 
     def get_engine_id(self) -> str:
         return get_model_and_label_id(self.model_name, self.label)
@@ -1528,52 +1537,76 @@ class OpenAiInferenceEngine(
             if v is not None
         }
 
-    def _infer(
+    def _parallel_infer(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
+        infer_func,
         return_meta_data: bool = False,
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        inputs = [(instance, return_meta_data) for instance in dataset]
         outputs = []
-
-
-
-
-
-
-
-            prediction = response.choices[0].message.content
-            output = self.get_return_object(prediction, response, return_meta_data)
-
-            outputs.append(output)
+        with ThreadPool(processes=self.num_parallel_requests) as pool:
+            for output in tqdm(
+                pool.imap(infer_func, inputs),
+                total=len(inputs),
+                desc=f"Inferring with {self.__class__.__name__}",
+            ):
+                outputs.append(output)
 
         return outputs
 
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], Dataset],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        return self._parallel_infer(
+            dataset=dataset,
+            return_meta_data=return_meta_data,
+            infer_func=self._get_chat_completion,
+        )
+
     def _infer_log_probs(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
         return_meta_data: bool = False,
     ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return self._parallel_infer(
+            dataset=dataset,
+            return_meta_data=return_meta_data,
+            infer_func=self._get_logprobs,
+        )
+
+    @run_with_imap
+    def _get_chat_completion(self, instance, return_meta_data):
+        messages = self.to_messages(instance)
+        response = self.client.chat.completions.create(
+            messages=messages,
+            model=self.model_name,
+            **self._get_completion_kwargs(),
+        )
+        prediction = response.choices[0].message.content
+        return self.get_return_object(prediction, response, return_meta_data)
+
+    @run_with_imap
+    def _get_logprobs(self, instance, return_meta_data):
+        messages = self.to_messages(instance)
+        response = self.client.chat.completions.create(
+            messages=messages,
+            model=self.model_name,
+            **self._get_completion_kwargs(),
+        )
+        top_logprobs_response = response.choices[0].logprobs.content
+        pred_output = [
+            {
+                "top_tokens": [
+                    {"text": obj.token, "logprob": obj.logprob}
+                    for obj in generated_token.top_logprobs
+                ]
+            }
+            for generated_token in top_logprobs_response
+        ]
+        return self.get_return_object(pred_output, response, return_meta_data)
 
     def get_return_object(self, predict_result, response, return_meta_data):
         if return_meta_data:
@@ -1807,16 +1840,19 @@ class WMLInferenceEngineBase(
 ):
     """Base for classes running inference using ibm-watsonx-ai.
 
-
-    credentials (Dict[str, str], optional):
+    Args:
+        credentials (Dict[str, str], optional):
+            By default, it is created by a class
             instance which tries to retrieve proper environment variables
             ("WML_URL", "WML_PROJECT_ID", "WML_SPACE_ID", "WML_APIKEY", "WML_USERNAME", "WML_PASSWORD").
             However, a dictionary with the following keys: "url", "apikey", "project_id", "space_id",
             "username", "password".
             can be directly provided instead.
-    model_name (str, optional):
+        model_name (str, optional):
+            ID of a model to be used for inference. Mutually
             exclusive with 'deployment_id'.
-    deployment_id (str, optional):
+        deployment_id (str, optional):
+            Deployment ID of a tuned model to be used for
             inference. Mutually exclusive with 'model_name'.
         parameters (Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin], optional):
             Defines inference parameters and their values. Deprecated attribute, please pass respective
@@ -2077,9 +2113,10 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
 
     If you want to include images in your input, please use 'WMLInferenceEngineChat' instead.
 
-
-    concurrency_limit (int):
-
+    Args:
+        concurrency_limit (int):
+            Number of concurrent requests sent to a model. Default is 10,
+            which is also the maximum value.
 
     Examples:
         .. code-block:: python
@@ -2207,10 +2244,11 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
     concatenate images within an instance into a single image and adjust your query
     accordingly (if necessary).
 
-
-    image_encoder (EncodeImageToString, optional):
-
-
+    Args:
+        image_encoder (EncodeImageToString, optional):
+            operator which encodes images in
+            given format to base64 strings required by service. You should specify it when
+            you are using images in your inputs.
 
     Example:
         .. code-block:: python
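The `run_with_imap` decorator exists because `ThreadPool.imap` hands each work item to the worker as a single argument; the decorator unpacks the `(instance, return_meta_data)` tuple back into positional arguments. A standalone sketch of the pattern (the `Engine.query` stand-in is illustrative, not part of this diff):

from multiprocessing.pool import ThreadPool


def run_with_imap(func):
    # imap calls inner(args) with one tuple; unpack it for the wrapped method.
    def inner(self, args):
        return func(self, *args)

    return inner


class Engine:
    num_parallel_requests = 4

    @run_with_imap
    def query(self, instance, return_meta_data):
        # stand-in for a blocking API call
        return (instance.upper(), return_meta_data)

    def infer(self, dataset, return_meta_data=False):
        inputs = [(instance, return_meta_data) for instance in dataset]
        with ThreadPool(processes=self.num_parallel_requests) as pool:
            # imap preserves input order while running requests concurrently
            return list(pool.imap(self.query, inputs))


print(Engine().infer(["a", "b", "c"]))  # [('A', False), ('B', False), ('C', False)]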
llm_as_judge.py
CHANGED
@@ -1,6 +1,6 @@
 import itertools
 from difflib import get_close_matches
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from .api import infer
 from .artifact import fetch_artifact
@@ -145,7 +145,7 @@ class LLMJudge(BulkInstanceMetric):
         )
         return
 
-    def get_contexts(self, task_data:
+    def get_contexts(self, task_data: List[Dict[str, Any]]) -> List[Dict[str, str]]:
         return [
             get_parsed_context(
                 {
@@ -161,7 +161,7 @@ class LLMJudge(BulkInstanceMetric):
         instances: list,
         task: Task,
         template: Template,
-        previous_messages: Optional[
+        previous_messages: Optional[List[Dict[str, str]]] = None,
     ):
         outputs_dataset = infer(
             instances,
@@ -172,11 +172,11 @@ class LLMJudge(BulkInstanceMetric):
             return_data=True,
             previous_messages=previous_messages,
         )
-        prompts:
-        raw_predictions:
+        prompts: List[str] = [instance["source"] for instance in outputs_dataset]
+        raw_predictions: List[str] = [
             instance["raw_prediction"] for instance in outputs_dataset
         ]
-        predictions:
+        predictions: List[str] = [
             instance["prediction"] for instance in outputs_dataset
         ]
         return (prompts, raw_predictions, predictions)
@@ -274,7 +274,7 @@ class LLMJudgeDirect(LLMJudge):
             raise Exception(
                 f"The type of the criteria must be 'CriteriaWithOptions', instead it is of type '{type(self.criteria)}'"
             )
-        criterias:
+        criterias: List[CriteriaWithOptions] = [self.criteria] * eval_count
         unique_criterias = list({criteria.name for criteria in criterias})
         self.logger.info(f"Criteria names are '{', '.join(unique_criterias)}'")
         return criterias
@@ -289,8 +289,8 @@ class LLMJudgeDirect(LLMJudge):
         option_selection_outputs,
         selections,
         evaluations_count,
-        criterias:
-    ) ->
+        criterias: List[CriteriaWithOptions],
+    ) -> List[Dict[str, Any]]:
         positional_bias = None
         if self.check_positional_bias:
             positional_bias = [
@@ -353,9 +353,9 @@ class LLMJudgeDirect(LLMJudge):
 
     def compute(
         self,
-        references:
-        predictions:
-        task_data:
+        references: List[List[str]],
+        predictions: List[str],
+        task_data: List[Dict[str, Any]],
     ) -> dict:
         self.logger.info(
             f'Starting evaluation with evaluator "{self.evaluator_name}" and provider "{self.inference_engine.get_pretty_print_name()}'
@@ -545,7 +545,7 @@ class LLMJudgePairwise(LLMJudge):
                 f"The type of the criteria must be 'Criteria', instead it is of type '{type(self.criteria)}'"
             )
 
-        criterias:
+        criterias: List[Criteria] = [self.criteria] * eval_count
 
         unique_criterias = list({criteria.name for criteria in criterias})
         self.logger.info(f"Criteria names are '{', '.join(unique_criterias)}'")
@@ -553,7 +553,7 @@ class LLMJudgePairwise(LLMJudge):
 
     def get_instance_results(
         self,
-        instance_predictions:
+        instance_predictions: Dict[str, str],
         assessment_prompts,
         assessment_outputs,
         summarization_prompts,
@@ -728,7 +728,7 @@ class LLMJudgePairwise(LLMJudge):
         all_results["criteria"] = criteria.to_json()
         return self.clean_results(all_results)
 
-    def parse_prediction_to_dict(self, prediction: Union[
+    def parse_prediction_to_dict(self, prediction: Union[Dict[str, str], List[str]]):
         if isinstance(prediction, list):
             return {f"{key + 1}": value for key, value in enumerate(prediction)}
 
@@ -740,15 +740,15 @@ class LLMJudgePairwise(LLMJudge):
         )
 
     def convert_predictions_to_dicts(
-        self, predictions: Union[
+        self, predictions: Union[List[Dict[str, str]], List[str]]
     ):
         return [self.parse_prediction_to_dict(prediction) for prediction in predictions]
 
     def compute(
         self,
-        references:
-        predictions: Union[
-        task_data:
+        references: List[List[str]],
+        predictions: Union[List[Dict[str, str]], List[str]],
+        task_data: List[Dict[str, str]],
     ) -> dict:
         self.logger.info(
             f'Starting evaluation with evaluator "{self.evaluator_name}" and provider {self.inference_engine.get_pretty_print_name()}'
@@ -775,8 +775,8 @@ class LLMJudgePairwise(LLMJudge):
             f"The evaluation will perform {sum(contests_count_list) * [1,2][self.check_positional_bias]} ({' + '.join([f'{c * [1,2][self.check_positional_bias]}' for c in contests_count_list])}) pairwise comparisons"
         )
 
-        response_pairs_list:
-        option_pairs_list:
+        response_pairs_list: List[List[List[str]]] = []
+        option_pairs_list: List[List[List[str]]] = []
         predictions_names = set(predictions[0].keys())
         for i, combination_indexes in enumerate(combination_indexes_list):
             instance_predictions = predictions[i]
@@ -786,8 +786,8 @@ class LLMJudgePairwise(LLMJudge):
                 f"The set of prediction names is different between instance 0 and instance {i}. In prediction 0, it is {sorted(predictions_names)}. In prediction {i}, it is {sorted(instance_predictions_names)}. Make sure the same number of predictions is passed for all instances."
             )
 
-            response_pairs:
-            option_pairs:
+            response_pairs: List[List[str]] = []
+            option_pairs: List[List[str]] = []
             for combination in combination_indexes:
                 (idx_1, idx_2) = combination
                 response_name_1 = instance_predictions_names[idx_1]
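For reference, the dict conversion that `parse_prediction_to_dict` applies to list-shaped predictions (taken directly from the comprehension in the diff above) behaves like this:

# A list of per-system predictions becomes a dict keyed "1", "2", ...
# so pairwise comparisons can refer to systems by name.
predictions = ["response from system A", "response from system B"]
as_dict = {f"{key + 1}": value for key, value in enumerate(predictions)}
print(as_dict)  # {'1': 'response from system A', '2': 'response from system B'}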
llm_as_judge_constants.py
CHANGED
@@ -1,6 +1,6 @@
 import json
 from enum import Enum
-from typing import Optional
+from typing import Dict, List, Optional
 
 from .artifact import Artifact
 from .inference import (
@@ -36,15 +36,15 @@ class Criteria(Artifact):
 
 
 class CriteriaWithOptions(Criteria):
-    options:
-    option_map: Optional[
+    options: List[CriteriaOption]
+    option_map: Optional[Dict[str, float]] = None
 
     @staticmethod
     def from_jsons(s: str):
         return CriteriaWithOptions.from_obj(json.loads(s))
 
     @staticmethod
-    def from_obj(criteria_dict:
+    def from_obj(criteria_dict: Dict):
         return CriteriaWithOptions(
             name=criteria_dict["name"],
             description=criteria_dict["description"],
@@ -132,7 +132,7 @@ PROVIDER_TO_STRATEGY = {
 
 class EvaluatorMetadata:
     name: EvaluatorNameEnum
-    providers:
+    providers: List[ModelProviderEnum]
 
     def __init__(self, name, providers):
         self.name = name
llm_as_judge_utils.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Dict
+
 from .llm_as_judge_constants import (
     EVALUATORS_METADATA,
     MODEL_RENAMINGS,
@@ -7,7 +9,7 @@ from .llm_as_judge_constants import (
 )
 
 
-def get_parsed_context(context:
+def get_parsed_context(context: Dict[str, str]):
     return (
         "\n".join([f"{key}: {value}" for key, value in context.items()])
         if len(context) > 1
loaders.py
CHANGED
@@ -41,6 +41,7 @@ from tempfile import TemporaryDirectory
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
 
 import pandas as pd
+from datasets import IterableDatasetDict
 from datasets import load_dataset as hf_load_dataset
 from huggingface_hub import HfApi
 from tqdm import tqdm
@@ -51,7 +52,7 @@ from .logging_utils import get_logger
 from .operator import SourceOperator
 from .operators import Set
 from .settings_utils import get_settings
-from .stream import DynamicStream, MultiStream
+from .stream import MultiStream
 from .type_utils import isoftype
 from .utils import LRUCache
 
@@ -122,7 +123,7 @@ class Loader(SourceOperator):
         )
         return operator(multi_stream)
 
-    def
+    def set_default_data_classification(
         self, default_data_classification_policy, additional_info
     ):
         if self.data_classification_policy is None:
@@ -162,23 +163,24 @@ class LoadHF(Loader):
     and it can filter datasets upon loading.
 
     Args:
-    path:
-
-    name:
-
-    data_dir:
-
-    split:
-
-    data_files:
-
-    revision:
-
-    streaming (bool):
-
-    filtering_lambda
-
-    num_proc (int):
+        path:
+            The path or identifier of the dataset on the HuggingFace Hub.
+        name:
+            An optional dataset name.
+        data_dir:
+            Optional directory to store downloaded data.
+        split:
+            Optional specification of which split to load.
+        data_files:
+            Optional specification of particular data files to load.
+        revision:
+            Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
+        streaming (bool):
+            indicating if streaming should be used.
+        filtering_lambda (str, optional):
+            A lambda function for filtering the data after loading.
+        num_proc (int, optional):
+            Specifies the number of processes to use for parallel dataset loading.
 
     Example:
         Loading glue's mrpc dataset
@@ -278,40 +280,22 @@ class LoadHF(Loader):
             for split in dataset.keys():
                 dataset[split] = dataset[split].to_iterable_dataset()
         else:
-            dataset = {self.split: dataset}
-
-        if self.filtering_lambda is not None:
-            dataset = self.filter_load(dataset)
+            dataset = {self.split: dataset.to_iterable_dataset()}
 
         return dataset
 
-    def split_limited_load(self, dataset, split_name):
-        yield from itertools.islice(dataset[split_name], self.get_limit())
-
-    def limited_load(self, dataset):
-        self.log_limited_loading()
-        return MultiStream(
-            {
-                name: DynamicStream(
-                    generator=self.split_limited_load,
-                    gen_kwargs={"dataset": dataset, "split_name": name},
-                )
-                for name in dataset.keys()
-            }
-        )
-
     def _maybe_set_classification_policy(self):
         if os.path.exists(self.path):
-            self.
+            self.set_default_data_classification(
                 ["proprietary"], "when loading from local files"
             )
         else:
-            self.
+            self.set_default_data_classification(
                 ["public"],
                 None,  # No warning when loading from public hub
             )
 
-    def load_iterables(self):
+    def load_iterables(self) -> IterableDatasetDict:
         try:
             dataset = self.stream_dataset()
         except (
@@ -319,8 +303,15 @@ class LoadHF(Loader):
         ):  # streaming is not supported for zipped files so we load without streaming
             dataset = self.load_dataset()
 
+        if self.filtering_lambda is not None:
+            dataset = self.filter_load(dataset)
+
         if self.get_limit() is not None:
-
+            self.log_limited_loading()
+            return {
+                split_name: dataset[split_name].take(self.get_limit())
+                for split_name in dataset
+            }
 
         return dataset
 
@@ -352,7 +343,7 @@ class LoadCSV(Loader):
     sep: str = ","
 
     def _maybe_set_classification_policy(self):
-        self.
+        self.set_default_data_classification(
             ["proprietary"], "when loading from local files"
         )
 
@@ -365,9 +356,7 @@ class LoadCSV(Loader):
                     file_path, nrows=self.get_limit(), sep=self.sep
                 ).to_dict("records")
             else:
-                iterables[split_name] = pd.read_csv(file_path
-                    "records"
-                )
+                iterables[split_name] = pd.read_csv(file_path).to_dict("records")
         return iterables
 
 
@@ -475,14 +464,22 @@ class LoadFromIBMCloud(Loader):
     3. Mapping: split -> file_names, e.g. {"test" : ["test1.json", "test2.json"], "train": ["train.json"]}
 
     Args:
-    endpoint_url_env:
-
-
-
-
-
-
-
+        endpoint_url_env:
+            Environment variable name for the IBM Cloud endpoint URL.
+        aws_access_key_id_env:
+            Environment variable name for the AWS access key ID.
+        aws_secret_access_key_env:
+            Environment variable name for the AWS secret access key.
+        bucket_name:
+            Name of the S3 bucket from which to load data.
+        data_dir:
+            Optional directory path within the bucket.
+        data_files:
+            Union type allowing either a list of file names or a mapping of splits to file names.
+        data_field:
+            The dataset key for nested JSON file, i.e. when multiple datasets are nested in the same file
+        caching (bool):
+            indicating if caching is enabled to avoid re-downloading data.
 
     Example:
         Loading from IBM Cloud
@@ -578,7 +575,7 @@ class LoadFromIBMCloud(Loader):
         raise NotImplementedError("LoadFromKaggle cannot load with streaming.")
 
     def _maybe_set_classification_policy(self):
-        self.
+        self.set_default_data_classification(
             ["proprietary"], "when loading from IBM COS"
         )
 
@@ -729,7 +726,7 @@ class LoadFromDictionary(Loader):
         )
 
     def _maybe_set_classification_policy(self):
-        self.
+        self.set_default_data_classification(
            ["proprietary"], "when loading from python dictionary"
        )
 
@@ -744,25 +741,24 @@ class LoadFromHFSpace(LoadHF):
     from the given space and then reads them as a HuggingFace Dataset.
 
     Args:
-    space_name (str):
-
-    data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        authentication when accessing the HuggingFace Space - if necessary.
+        space_name (str):
+            Name of the HuggingFace Space to be accessed.
+        data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]):
+            Relative paths to files within a given repository. If given as a mapping,
+            paths should be values, while keys should represent the type of respective files
+            (training, testing etc.).
+        path (str, optional):
+            Absolute path to a directory where data should be downloaded.
+        revision (str, optional):
+            ID of a Git branch or commit to be used. By default, it is set to None,
+            thus data is downloaded from the main branch of the accessed repository.
+        use_token (bool, optional):
+            Whether a token is used for authentication when accessing
+            the HuggingFace Space. If necessary, the token is read from the HuggingFace
+            config folder.
+        token_env (str, optional):
+            Key of an env variable which value will be used for
+            authentication when accessing the HuggingFace Space - if necessary.
 
     Example:
         Loading from a HuggingFace Space
@@ -910,7 +906,7 @@ class LoadFromHFSpace(LoadHF):
         )
 
     def _maybe_set_classification_policy(self):
-        self.
+        self.set_default_data_classification(
             ["public"], "when loading from Huggingface spaces"
         )
 
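The new `load_iterables` applies the loading limit with the `datasets` library's streaming `take()` instead of the removed `DynamicStream` wrapper. A minimal sketch of that pattern outside unitxt (the GLUE/MRPC dataset is illustrative):

from datasets import load_dataset

# Stream the dataset and keep only the first N examples per split,
# mirroring the take()-based limiting introduced in load_iterables above.
limit = 5
dataset = load_dataset("glue", "mrpc", streaming=True)
limited = {split_name: dataset[split_name].take(limit) for split_name in dataset}
print([example["sentence1"][:40] for example in limited["train"]])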
metric_utils.py
CHANGED
@@ -353,13 +353,11 @@ UNITXT_METRIC_SCHEMA = Features(
 class GlobalScores(dict):
     """GlobalScores is a dictionary-based class designed to handle and transform metric results into a structured format.
 
-
-    score (float):
-
-
-
-    to_df():
-        Transforms the dictionary of results into a pandas DataFrame with score_name as the index,
+    Args:
+        score (float):
+            The main score value.
+        score_name (str):
+            The name of the main score.
     """
 
     @property
@@ -550,12 +548,11 @@ class GroupsScores(dict):
     This class provides a property to summarize the scores and a custom
     string representation for pretty-printing.
 
-    Attributes:
-        summary (property): A property to get a summary of the group scores.
     """
 
     @property
     def summary(self):
+        """A property to get a summary of the group scores."""
         data = self
         # Desired metric columns
         metric_cols = [
metrics.py
CHANGED
@@ -48,7 +48,7 @@ from .random_utils import get_seed
 from .settings_utils import get_settings
 from .stream import MultiStream, Stream
 from .type_utils import Type, isoftype, parse_type_string, to_type_string
-from .utils import deep_copy
+from .utils import deep_copy, recursive_copy
 
 logger = get_logger()
 settings = get_settings()
@@ -992,7 +992,17 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
     reference_field: str = NonPositionalField(default="references")
     prediction_field: str = NonPositionalField(default="prediction")
 
-    def
+    def _validate_group_mean_task_data(self, instance):
+        # instances need to all have task_data field with field group_id
+        assert "task_data" in instance, "each instance must have an task_data field"
+        assert isinstance(
+            instance["task_data"], dict
+        ), "each instance must have an task_data field that is a dict"
+        assert (
+            "group_id" in instance["task_data"]
+        ), "each instance task_data dict must have a key group_id"
+
+    def _validate_group_mean_reduction(self):
         """Ensure that group_mean reduction_map is properly formatted.
 
         Example: Apply the variance (np.var) to group Accuracy instance scores. This class would be specified as follows:
@@ -1042,17 +1052,6 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
         1  'How do I repair my engine?'    'paraphrase'
         2  'Why are ants eating my food?'  'original'
         """
-        # instances need to all have task_data field with field group_id
-        assert all(
-            "task_data" in instance for instance in instances
-        ), "each instance must have an task_data field"
-        assert all(
-            isinstance(instance["task_data"], dict) for instance in instances
-        ), "each instance must have an task_data field that is a dict"
-        assert all(
-            "group_id" in instance["task_data"] for instance in instances
-        ), "each instance task_data dict must have a key group_id"
-
         # validate the reduction_map
         assert (
             "group_mean" in self.reduction_map
@@ -1081,16 +1080,9 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
         if "score_fields" in fields:
             assert isinstance(fields["score_fields"], list)
 
-        # for aggregation functions that use the subgroup_column (expect a dict of lists), check that
-        # this field exists
-        if self.subgroup_column is not None:
-            assert all(
-                self.subgroup_column in instance["task_data"] for instance in instances
-            ), f"each instance task_data dict must have a key {self.subgroup_column}"
-
     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
-
-        global_score = {"num_of_instances": len(
+        instance_scores = self.compute_instance_scores(stream)
+        global_score = {"num_of_instances": len(instance_scores)}
         for reduction_type, reduction_params in self.reduction_map.items():
             assert (
                 reduction_type in self.implemented_reductions
@@ -1103,15 +1095,15 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
                 aggregation_function = self.average_item_scores
                 reduction_fields = list(set(reduction_params))
                 # no group reduction, so resample instances individually
-                scores_to_resample =
+                scores_to_resample = instance_scores
             elif reduction_type == "max":
                 aggregation_function = self.max_item_scores
                 reduction_fields = list(set(reduction_params))
                 # no group reduction, so resample instances individually
-                scores_to_resample =
+                scores_to_resample = instance_scores
             elif reduction_type == "group_mean":
                 aggregation_function = self.average_item_scores
-                self._validate_group_mean_reduction(
+                self._validate_group_mean_reduction()
                 reduction_fields = (
                     [self.main_score]
                     if "score_fields" not in reduction_params
@@ -1127,7 +1119,7 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
                     scores_to_resample,
                     aggregation_function,
                 ) = self._set_up_group_mean_aggregation(
-
                     reduction_params,
                     reduction_fields,
                 )
@@ -1168,18 +1160,32 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
             )
             global_score.update(confidence_interval)
 
-        for instance in
             self.update_and_adjust_global_score(instance, global_score)
-
 
     def compute_instance_scores(
         self, stream: Stream, stream_name: Optional[str] = None
     ):
-
 
        for instance in stream:
            instance = self.verify_instance(instance)
 
            task_data = instance["task_data"] if "task_data" in instance else {}
 
            if self.reference_field == "references":
@@ -1214,9 +1220,18 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
                     instance_score, instance["score"]["instance"]
                 )
             )
-
 
-
 
     def get_group_scores(
         self,
@@ -1228,12 +1243,16 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
         """Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
 
         Args:
-            instances
-
-
             or, if self.subgroup_column is not None, a dict of subgroup types scores by subgroup_column value.
             callable function returns a single score for the group
-            prepend_score_prefix:
             if down the stream such a prepending is expected.
 
         Returns:
@@ -4910,14 +4929,18 @@ class IsCodeMixed(BulkInstanceMetric):
 class MetricsEnsemble(InstanceMetric, ArtifactFetcherMixin):
     """Metrics Ensemble class for creating ensemble of given metrics.
 
-
-    main_score (str):
-
-
-
-
-
-
 
     """
 
|
1122 |
+
instance_scores,
|
1123 |
reduction_params,
|
1124 |
reduction_fields,
|
1125 |
)
|
|
|
1160 |
)
|
1161 |
global_score.update(confidence_interval)
|
1162 |
|
1163 |
+
for instance in instance_scores:
|
1164 |
self.update_and_adjust_global_score(instance, global_score)
|
1165 |
+
|
1166 |
+
for i, instance in enumerate(stream):
|
1167 |
+
instance["score"] = recursive_copy(instance_scores[i]["score"])
|
1168 |
+
yield instance
|
1169 |
|
1170 |
def compute_instance_scores(
|
1171 |
self, stream: Stream, stream_name: Optional[str] = None
|
1172 |
):
|
1173 |
+
instance_scores = []
|
1174 |
|
1175 |
for instance in stream:
|
1176 |
instance = self.verify_instance(instance)
|
1177 |
|
1178 |
+
if "group_mean" in self.reduction_map:
|
1179 |
+
self._validate_group_mean_task_data(instance)
|
1180 |
+
|
1181 |
+
# for aggregation functions that use the subgroup_column (expect a dict of lists), check that
|
1182 |
+
# this field exists
|
1183 |
+
if self.subgroup_column is not None:
|
1184 |
+
assert (
|
1185 |
+
"task_data" in instance
|
1186 |
+
and self.subgroup_column in instance["task_data"]
|
1187 |
+
), f"each instance task_data dict must have a key {self.subgroup_column}"
|
1188 |
+
|
1189 |
task_data = instance["task_data"] if "task_data" in instance else {}
|
1190 |
|
1191 |
if self.reference_field == "references":
|
|
|
1220 |
instance_score, instance["score"]["instance"]
|
1221 |
)
|
1222 |
)
|
1223 |
+
task_data = {}
|
1224 |
+
if "task_data" in instance:
|
1225 |
+
if "group_id" in instance["task_data"]:
|
1226 |
+
task_data["group_id"] = instance["task_data"]["group_id"]
|
1227 |
+
if self.subgroup_column in instance["task_data"]:
|
1228 |
+
task_data[self.subgroup_column] = instance["task_data"][
|
1229 |
+
self.subgroup_column
|
1230 |
+
]
|
1231 |
|
1232 |
+
instance_scores.append({"score": instance["score"], "task_data": task_data})
|
1233 |
+
|
1234 |
+
return instance_scores
|
1235 |
|
1236 |
def get_group_scores(
|
1237 |
self,
|
|
|
1243 |
"""Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
|
1244 |
|
1245 |
Args:
|
1246 |
+
instances (list):
|
1247 |
+
List of observation instances with instance-level scores (fields) computed.
|
1248 |
+
score_names (list):
|
1249 |
+
List of instance score names in each instance to apply the aggregation function.
|
1250 |
+
group_aggregation_func (Callable):
|
1251 |
+
aggregation function accepting a list of numeric scores;
|
1252 |
or, if self.subgroup_column is not None, a dict of subgroup types scores by subgroup_column value.
|
1253 |
callable function returns a single score for the group
|
1254 |
+
prepend_score_prefix (bool):
|
1255 |
+
if True - prepend the score_prefix to the score names in the returned dicts. Set to False
|
1256 |
if down the stream such a prepending is expected.
|
1257 |
|
1258 |
Returns:
|
|
|
4929 |
class MetricsEnsemble(InstanceMetric, ArtifactFetcherMixin):
|
4930 |
"""Metrics Ensemble class for creating ensemble of given metrics.
|
4931 |
|
4932 |
+
Args:
|
4933 |
+
main_score (str):
|
4934 |
+
The main score label used for evaluation.
|
4935 |
+
metrics (List[Union[Metric, str]]):
|
4936 |
+
List of metrics that will be ensemble.
|
4937 |
+
weights (List[float]):
|
4938 |
+
Weight of each the metrics
|
4939 |
+
reduction_map (Dict[str, List[str]]):
|
4940 |
+
Specifies the redaction method of the global score.
|
4941 |
+
InstanceMetric currently allows two reductions
|
4942 |
+
(see it definition at InstanceMetric class).
|
4943 |
+
This class define its default value to reduce by the mean of the main score.
|
4944 |
|
4945 |
"""
|
4946 |
|
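Note: the `group_mean` reduction now validates each instance on its own (`task_data` must be a dict containing `group_id`) and aggregates in two levels: instance scores are grouped by `task_data["group_id"]`, each group is reduced, and the global score averages over groups. A minimal, self-contained sketch of that two-level aggregation, with toy data and standard-library calls only (field values here are illustrative):

```python
from collections import defaultdict
from statistics import mean

# Toy instance scores shaped like the dicts compute_instance_scores now returns:
# {"score": {"instance": {...}}, "task_data": {"group_id": ...}}
instance_scores = [
    {"score": {"instance": {"accuracy": 1.0}}, "task_data": {"group_id": "g1"}},
    {"score": {"instance": {"accuracy": 0.0}}, "task_data": {"group_id": "g1"}},
    {"score": {"instance": {"accuracy": 1.0}}, "task_data": {"group_id": "g2"}},
]

groups = defaultdict(list)
for item in instance_scores:
    groups[item["task_data"]["group_id"]].append(item["score"]["instance"]["accuracy"])

# group_mean: aggregate within each group, then average the group scores
group_scores = [mean(scores) for scores in groups.values()]
print(mean(group_scores))  # 0.75
```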
operator.py
CHANGED
@@ -222,11 +222,11 @@ class SourceOperator(MultiStreamOperator):

     A source operator is responsible for generating the data stream from some source, such as a database or a file.
     This is the starting point of a stream processing pipeline.
-    The
+    The ``SourceOperator`` class is a type of ``MultiStreamOperator``, which is a special type of ``StreamingOperator``
     that generates an output stream but does not take any input streams.

-    When called, a
-    to generate the required
+    When called, a ``SourceOperator`` invokes its ``process`` method, which should be implemented by all subclasses
+    to generate the required ``MultiStream``.

     """

@@ -247,9 +247,14 @@ class SourceOperator(MultiStreamOperator):

 class StreamInitializerOperator(SourceOperator):
     """A class representing a stream initializer operator in the streaming system.

-    A stream initializer operator is a special type of
+    A stream initializer operator is a special type of ``SourceOperator`` that is capable
+    of taking parameters during the stream generation process.
+    This can be useful in situations where the stream generation process needs to be
+    customized or configured based on certain parameters.

-    When called, a
+    When called, a ``StreamInitializerOperator`` invokes its ``process`` method, passing any supplied
+    arguments and keyword arguments. The ``process`` method should be implemented by all subclasses
+    to generate the required ``MultiStream`` based on the given arguments and keyword arguments.

     """

@@ -278,11 +283,12 @@ def instance_result(result_stream):

 class StreamOperator(MultiStreamOperator):
     """A class representing a single-stream operator in the streaming system.

-    A single-stream operator is a type of
-    and applies the
+    A single-stream operator is a type of ``MultiStreamOperator`` that operates on individual
+    ``Stream`` objects within a ``MultiStream``. It iterates through each ``Stream`` in the ``MultiStream``
+    and applies the ``process`` method.
+
+    The ``process`` method should be implemented by subclasses to define the specific operations
+    to be performed on each ``Stream``.

     """

@@ -353,13 +359,15 @@ class SingleStreamOperator(StreamOperator):

 class PagedStreamOperator(StreamOperator):
     """A class representing a paged-stream operator in the streaming system.

-    A paged-stream operator is a type of
-    in a
-    The
+    A paged-stream operator is a type of ``StreamOperator`` that operates on a page of instances
+    in a ``Stream`` at a time, where a page is a subset of instances.
+    The ``process`` method should be implemented by subclasses to define the specific operations
     to be performed on each page.

     Args:
-        page_size (int):
+        page_size (int):
+            The size of each page in the stream. Defaults to 1000.
+
     """

     page_size: int = 1000

@@ -393,7 +401,12 @@ class PagedStreamOperator(StreamOperator):

 class SingleStreamReducer(StreamingOperator):
     """A class representing a single-stream reducer in the streaming system.

-    A single-stream reducer is a type of
+    A single-stream reducer is a type of ``StreamingOperator`` that operates on individual
+    ``Stream`` objects within a ``MultiStream`` and reduces each ``Stream`` to a single output value.
+
+    The ``process`` method should be implemented by subclasses to define the specific reduction operation
+    to be performed on each ``Stream``.
+
     """

     def __call__(self, multi_stream: Optional[MultiStream] = None) -> Dict[str, Any]:

@@ -412,7 +425,10 @@ class SingleStreamReducer(StreamingOperator):

 class InstanceOperator(StreamOperator):
     """A class representing a stream instance operator in the streaming system.

-    A stream instance operator is a type of
+    A stream instance operator is a type of ``StreamOperator`` that operates on individual instances
+    within a ``Stream``. It iterates through each instance in the ``Stream`` and applies the ``process`` method.
+    The ``process`` method should be implemented by subclasses to define the specific operations
+    to be performed on each instance.
     """

     def _process_stream(

@@ -449,7 +465,8 @@ class InstanceOperator(StreamOperator):

 class InstanceOperatorValidator(InstanceOperator):
     """A class representing a stream instance operator validator in the streaming system.

-    A stream instance operator validator is a type of
+    A stream instance operator validator is a type of ``InstanceOperator`` that includes a validation step.
+    It operates on individual instances within a ``Stream`` and validates the result of processing each instance.
     """

     @abstractmethod
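The docstrings above all describe the same contract: subclasses implement `process`, and the framework drives iteration. A minimal illustration of that pattern, using a stand-in base class so the snippet runs on its own (`InstanceOperatorSketch`, `AddPrefix`, and the field name `text` are hypothetical, not unitxt API):

```python
from typing import Any, Dict, Optional

class InstanceOperatorSketch:
    """Stand-in for unitxt's InstanceOperator: applies process() to each instance."""

    def process(self, instance: Dict[str, Any], stream_name: Optional[str] = None) -> Dict[str, Any]:
        raise NotImplementedError

    def __call__(self, stream):
        for instance in stream:
            yield self.process(instance)

class AddPrefix(InstanceOperatorSketch):
    def process(self, instance, stream_name=None):
        instance["text"] = "prefix: " + instance["text"]
        return instance

print(list(AddPrefix()([{"text": "hello"}])))  # [{'text': 'prefix: hello'}]
```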
operators.py
CHANGED
@@ -66,6 +66,7 @@ from .artifact import Artifact, fetch_artifact

 from .dataclass import NonPositionalField, OptionalField
 from .deprecation_utils import deprecation
 from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
+from .generator_utils import ReusableGenerator
 from .operator import (
     InstanceOperator,
     MultiStream,

@@ -81,7 +82,7 @@ from .operator import (

 )
 from .random_utils import new_random_generator
 from .settings_utils import get_settings
-from .stream import DynamicStream,
+from .stream import DynamicStream, Stream
 from .text_utils import nested_tuple_to_string
 from .type_utils import isoftype
 from .utils import (

@@ -132,23 +133,24 @@ class IterableSource(SourceOperator):

 class MapInstanceValues(InstanceOperator):
     """A class used to map instance values into other values.

-    This class is a type of InstanceOperator
+    This class is a type of ``InstanceOperator``,
     it maps values of instances in a stream using predefined mappers.

-    mappers (Dict[str, Dict[str, Any]]):
-    strict (bool):
-    process_every_value (bool):
+    Args:
+        mappers (Dict[str, Dict[str, Any]]):
+            The mappers to use for mapping instance values.
+            Keys are the names of the fields to undergo mapping, and values are dictionaries
+            that define the mapping from old values to new values.
+            Note that mapped values are defined by their string representation, so mapped values
+            are converted to strings before being looked up in the mappers.
+        strict (bool):
+            If True, the mapping is applied strictly. That means if a value
+            does not exist in the mapper, it will raise a KeyError. If False, values
+            that are not present in the mapper are kept as they are.
+        process_every_value (bool):
+            If True, all fields to be mapped should be lists, and the mapping
+            is to be applied to their individual elements.
+            If False, mapping is only applied to a field containing a single value.

     Examples:
         ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
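A minimal standalone sketch of the string-keyed lookup described above (toy data; the real operator additionally handles nested field paths and `process_every_value`):

```python
mappers = {"a": {"1": "hi", "2": "bye"}}

def map_instance_values(instance, mappers, strict=True):
    for field, mapper in mappers.items():
        key = str(instance[field])  # mapped values are looked up by string representation
        if key in mapper:
            instance[field] = mapper[key]
        elif strict:
            raise KeyError(f"value {key!r} of field {field!r} not found in mapper")
    return instance

print(map_instance_values({"a": 1}, mappers))  # {'a': 'hi'}
```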
@@ -335,23 +337,23 @@ class InstanceFieldOperator(InstanceOperator):

     """A general stream instance operator that processes the values of a field (or multiple ones).

     Args:
-        field (Optional[str]):
-        to_field (Optional[str]):
-        field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]):
+        field (Optional[str]):
+            The field to process, if only a single one is passed. Defaults to None
+        to_field (Optional[str]):
+            Field name to save result into, if only one field is processed, if None is passed the
+            operation would happen in-place and its result would replace the value of ``field``. Defaults to None
+        field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]):
+            Mapping from names of fields to process,
+            to names of fields to save the results into. Inner List, if used, should be of length 2.
+            A field is processed by feeding its value into method ``process_value`` and storing the result in ``to_field`` that
+            is mapped to the field. When the type of argument ``field_to_field`` is List, the order by which the fields are processed is their order
+            in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
+            order. The end result might depend on that order if either (1) two different fields are mapped to the same
+            to_field, or (2) a field shows both as a key and as a value in different mappings.
+            The operator throws an AssertionError in either of these cases. ``field_to_field``
+            defaults to None.
+        process_every_value (bool):
+            Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False

     Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field`` ) are equal (or share a common
     prefix if ``field`` and ``to_field`` contain a / ), then the result of the operation is saved within ``field`` .

@@ -806,10 +808,16 @@ class TakeByField(InstanceOperator):

 class Perturb(FieldOperator):
-    """Slightly perturbs the contents of
+    """Slightly perturbs the contents of ``field``. Can be handy for imitating a prediction from a given target.

-    When task was classification, argument
+    When the task is classification, argument ``select_from`` can be used to list the other potential classes, as a
     relevant perturbation
+
+    Args:
+        percentage_to_perturb (int):
+            the percentage of the instances for which to apply this perturbation. Defaults to 1 (1 percent)
+        select_from (List[Any]):
+            a list of values to select from, as a perturbation of the field's value. Defaults to [].
     """

     select_from: List[Any] = []
@@ -937,12 +945,13 @@ class CastFields(InstanceOperator):

     """Casts specified fields to specified types.

     Args:
-        fields (Dict[str, str]):
-        defaults (Dict[str, object]):
-        process_every_value (bool):
+        fields (Dict[str, str]):
+            A dictionary mapping field names to the names of the types to cast the fields to.
+            e.g: "int", "str", "float", "bool". Basic names of types
+        defaults (Dict[str, object]):
+            A dictionary mapping field names to default values for cases of casting failure.
+        process_every_value (bool):
+            If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.

     Example:
         .. code-block:: python
@@ -1268,16 +1277,19 @@ class FilterByExpression(StreamOperator, ComputeExpressionMixin):

     Raises an error if a field participating in the specified condition is missing from the instance

     Args:
+        expression (str):
+            a condition over fields of the instance, to be processed by python's eval()
+        imports_list (List[str]):
+            names of imports needed for the eval of the query (e.g. 're', 'json')
+        error_on_filtered_all (bool, optional):
+            If True, raises an error if all instances are filtered out. Defaults to True.

     Examples:
+        | ``FilterByExpression(expression = "a > 4")`` will yield only instances where "a">4
+        | ``FilterByExpression(expression = "a <= 4 and b > 5")`` will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5
+        | ``FilterByExpression(expression = "a in [4, 8]")`` will yield only instances where "a" is 4 or 8
+        | ``FilterByExpression(expression = "a not in [4, 8]")`` will yield only instances where "a" is neither 4 nor 8
+        | ``FilterByExpression(expression = "a['b'] not in [4, 8]")`` will yield only instances where "a" is a dict in which key 'b' is mapped to a value that is neither 4 nor 8
     """

     error_on_filtered_all: bool = True
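The essence of this operator is evaluating a boolean expression against each instance's fields, with the fields acting as variables, exactly as the Args describe. A standalone sketch (note that `eval` executes arbitrary code, so expressions must be trusted):

```python
instances = [{"a": 1}, {"a": 5}, {"a": 8}]

def filter_by_expression(stream, expression):
    for instance in stream:
        if eval(expression, {}, instance):  # instance fields act as variables
            yield instance

print(list(filter_by_expression(instances, "a > 4")))  # [{'a': 5}, {'a': 8}]
```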
@@ -1635,23 +1647,17 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):

     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         from .metrics import Metric, MetricsList

-        # so that the evaluation of one does not affect the evaluation of another
-        # (typically, affecting via change of instance as part of
-        # preprocess_steps of MetricPipeline, as illustrated in docs/adding_metrics/Using Metric Pipelines).
-        instances_upon_entrance_to_metrics_evaluations = []
-        for instance in stream:
-            instances_upon_entrance_to_metrics_evaluations.append(
-                recursive_copy(instance)
-            )
+        def update_scores_of_stream_instances(
+            stream: Stream, scores: List[dict]
+        ) -> Generator:
+            for instance, score in zip(stream, scores):
+                instance["score"] = recursive_copy(score)
+                yield instance
+
+        # to be populated only when two or more metrics
+        accumulated_scores = []
+
+        first_instance = stream.peek()

         metric_names = first_instance.get(self.metric_field, [])
         if not metric_names:

@@ -1680,26 +1686,28 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):

         # by the first listed metric (as desired).
         metrics_list = list(reversed(metrics_list))

-        for metric in metrics_list:
+        for metric_no, metric in enumerate(metrics_list):
             if not self.calc_confidence_intervals:
                 metric.disable_confidence_interval_calculation()
-            }
-            )
-            multi_stream = metric(multi_stream)
-            for evaluated_instance, freezed_instance in zip(
-                multi_stream["tmp"], instances_upon_entrance_to_metrics_evaluations
-            ):
-                freezed_instance["score"] = recursive_shallow_copy(
-                    evaluated_instance["score"]
-                )
+
+            if metric_no > 0:
+                # update input stream with accumulated scores
+                reusable_generator = ReusableGenerator(
+                    generator=update_scores_of_stream_instances,
+                    gen_kwargs={"stream": stream, "scores": accumulated_scores},
+                )
+                multi_stream = MultiStream.from_generators({"tmp": reusable_generator})
+            else:
+                multi_stream = MultiStream.from_iterables({"tmp": stream})
+            multi_stream = metric(multi_stream)
+            if metric_no < len(metrics_list) - 1:
+                # not the last metric, so prepare for the next metric by
+                # updating accumulated_scores
+                accumulated_scores = []
+                for inst in multi_stream["tmp"]:
+                    accumulated_scores.append(recursive_copy(inst["score"]))

-        yield from
+        yield from multi_stream["tmp"]


 class MergeStreams(MultiStreamOperator):
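The new control flow threads each metric's scores into the stream consumed by the next metric, instead of freezing full instance copies up front. A toy illustration of that hand-off pattern (standard library only; `apply_metric` is a stand-in for a real metric, not unitxt API):

```python
from copy import deepcopy

def update_scores(stream, scores):
    # re-attach previously accumulated scores before the next metric runs
    for instance, score in zip(stream, scores):
        instance["score"] = deepcopy(score)
        yield instance

def apply_metric(stream, name):
    # stand-in metric: adds its own entry to the instance's score dict
    for instance in stream:
        instance.setdefault("score", {})[name] = 1.0
        yield instance

instances = [{"x": 1}, {"x": 2}]
accumulated = []
for metric_no, name in enumerate(["rouge", "accuracy"]):
    source = instances if metric_no == 0 else update_scores(instances, accumulated)
    out = list(apply_metric(source, name))
    accumulated = [deepcopy(inst["score"]) for inst in out]

print(accumulated)  # [{'rouge': 1.0, 'accuracy': 1.0}, {'rouge': 1.0, 'accuracy': 1.0}]
```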
@@ -1872,13 +1880,15 @@ class StreamRefiner(StreamOperator):

     input stream. And if the input stream consists of more than 'max_instances' instances, the resulting stream only consists
     of the leading 'max_instances' of the input stream.

+    Args:
+        max_instances (int):
+        apply_to_streams (optional, list(str)):
+            names of streams to refine.

     Examples:
-        when input = [{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}] is fed into
-        StreamRefiner(max_instances=4)
-        the resulting stream is [{"a": 1},{"a": 2},{"a": 3},{"a": 4}]
+        when input = ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}]`` is fed into
+        ``StreamRefiner(max_instances=4)``
+        the resulting stream is ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4}]``
     """

     max_instances: int = None

@@ -1899,18 +1909,20 @@ class DeterministicBalancer(StreamRefiner):

     When also input 'max_instances' is specified, DeterministicBalancer maintains a total instance count not exceeding
     'max_instances'. The total number of discarded instances is as few as possible.

-    fields (List[str]):
+    Args:
+        fields (List[str]):
+            A list of field names to be used in producing the instance's signature.
+        max_instances (Optional, int):
+            overall max.

     Usage:
-       balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)
-       balanced_stream = balancer.process(stream)
+       ``balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)``
+       ``balanced_stream = balancer.process(stream)``

     Example:
-        When input [{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}] is fed into
-        DeterministicBalancer(fields=["a"])
-        the resulting stream will be: [{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]
+        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}]`` is fed into
+        ``DeterministicBalancer(fields=["a"])``
+        the resulting stream will be: ``[{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]``
     """

     fields: List[str]

@@ -1947,24 +1959,28 @@ class MinimumOneExamplePerLabelRefiner(StreamRefiner):

 class MinimumOneExamplePerLabelRefiner(StreamRefiner):
     """A class used to return a specified number of instances ensuring at least one example per label.

-    For each instance, a signature value is constructed from the values of the instance in specified input
-    MinimumOneExamplePerLabelRefiner takes first instance that appears from each label (each unique signature), and then adds more elements up to the max_instances limit. In general, the refiner takes the first elements in the stream that meet the required conditions.
-    MinimumOneExamplePerLabelRefiner then shuffles the results to avoid having one instance
+    For each instance, a signature value is constructed from the values of the instance in specified input ``fields``.
+    ``MinimumOneExamplePerLabelRefiner`` takes the first instance that appears from each label (each unique signature), and then adds more elements up to the max_instances limit. In general, the refiner takes the first elements in the stream that meet the required conditions.
+    ``MinimumOneExamplePerLabelRefiner`` then shuffles the results to avoid having one instance
     from each class first and then the rest. If max_instances is not set, the original stream will be used

+    Args:
+        fields (List[str]):
+            A list of field names to be used in producing the instance's signature.
+        max_instances (Optional, int):
+            Number of elements to select. Note that max_instances of StreamRefiners
+            that are passed to the recipe (e.g. ``train_refiner``, ``test_refiner``) are overridden
+            by the recipe parameters ( ``max_train_instances``, ``max_test_instances``)

     Usage:
-        balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)
-        balanced_stream = balancer.process(stream)
+        | ``balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)``
+        | ``balanced_stream = balancer.process(stream)``

     Example:
-        When input [{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}] is fed into
-        MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)
+        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}]`` is fed into
+        ``MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)``
         the resulting stream will be:
-        [{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}] (order may be different)
+        ``[{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}]`` (order may be different)
     """

     fields: List[str]

@@ -2022,20 +2038,19 @@ class LengthBalancer(DeterministicBalancer):

     """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.

     Args:
-        segments_boundaries (List[int]):
-        fields (Optional, List[str])
+        segments_boundaries (List[int]):
+            distinct integers sorted in increasing order, that map a given total length
+            into the index of the least of them that exceeds the given total length.
+            (If none exceeds -- into one index beyond, namely, the length of segments_boundaries)
+        fields (Optional, List[str]):
+            the total length of the values of these fields goes through the quantization described above

     Example:
-        when input [{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}]
-        LengthBalancer(fields=["a"], segments_boundaries=[1])
-        input instances will be counted and balanced against two categories: empty total length (less than 1), and non-empty.
+        when input ``[{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}]``
+        is fed into ``LengthBalancer(fields=["a"], segments_boundaries=[1])``,
+        input instances will be counted and balanced against two categories:
+        empty total length (less than 1), and non-empty.
     """

     segments_boundaries: List[int]
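The boundary-to-segment mapping described for `segments_boundaries` is exactly a right-bisect lookup; a small self-contained sketch:

```python
from bisect import bisect_right

def length_segment(total_length, segments_boundaries):
    # index of the least boundary exceeding total_length;
    # len(segments_boundaries) if none exceeds it
    return bisect_right(segments_boundaries, total_length)

print(length_segment(0, [1]))  # 0 -> "empty" bucket
print(length_segment(2, [1]))  # 1 -> "non-empty" bucket
```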
@@ -2067,9 +2082,11 @@ class UnexpectedHttpCodeError(Exception):

 class DownloadOperator(SideEffectOperator):
     """Operator for downloading a file from a given URL to a specified local path.

-    source (str):
+    Args:
+        source (str):
+            URL of the file to be downloaded.
+        target (str):
+            Local path where the downloaded file should be saved.
     """

     source: str

@@ -2089,9 +2106,11 @@ class ExtractZipFile(SideEffectOperator):

 class ExtractZipFile(SideEffectOperator):
     """Operator for extracting files from a zip archive.

-    zip_file (str):
+    Args:
+        zip_file (str):
+            Path of the zip file to be extracted.
+        target_dir (str):
+            Directory where the contents of the zip file will be extracted.
     """

     zip_file: str

@@ -2105,8 +2124,9 @@ class DuplicateInstances(StreamOperator):

 class DuplicateInstances(StreamOperator):
     """Operator which duplicates each instance in stream a given number of times.

-    num_duplications (int):
+    Args:
+        num_duplications (int):
+            How many times each instance should be duplicated (1 means no duplication).
     duplication_index_field (Optional[str]):
         If given, then additional field with specified name is added to each duplicated instance,
         which contains id of a given duplication. Defaults to None, so no field is added.
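A standalone sketch of the duplication behavior just documented, including the optional index field:

```python
def duplicate_instances(stream, num_duplications, duplication_index_field=None):
    for instance in stream:
        for i in range(num_duplications):
            duplicate = dict(instance)  # shallow copy per duplication
            if duplication_index_field:
                duplicate[duplication_index_field] = i
            yield duplicate

print(list(duplicate_instances([{"a": 1}], 2, "dup_id")))
# [{'a': 1, 'dup_id': 0}, {'a': 1, 'dup_id': 1}]
```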
processors.py
CHANGED
@@ -132,6 +132,14 @@ class TakeFirstNonEmptyLine(FieldOperator):

         return parts[0].strip()


+class TakeLastNonEmptyLine(FieldOperator):
+    def process_value(self, text: Any) -> Any:
+        parts = str(text).strip().split("\n")
+        if len(parts) == 0:
+            return ""
+        return parts[-1].strip()
+
+
 class ConvertToBoolean(FieldOperator):
     def process_value(self, text: Any) -> Any:
         clean_instance = str(text).strip().lower()

@@ -157,6 +165,11 @@ class Lower(FieldOperator):

         return text.lower()


+class Upper(FieldOperator):
+    def process_value(self, text: Any) -> Any:
+        return str(text).upper()
+
+
 @deprecation("2.0.0", alternative=Lower)
 class LowerCase(Lower):
     pass
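For reference, the value-level behavior of the two new processors, as standalone functions (toy input):

```python
def take_last_non_empty_line(text) -> str:
    parts = str(text).strip().split("\n")
    return parts[-1].strip() if parts else ""

def upper(text) -> str:
    return str(text).upper()

print(take_last_non_empty_line("some reasoning\n\nfinal answer"))  # "final answer"
print(upper("yes"))  # "YES"
```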
schema.py
CHANGED
@@ -143,6 +143,9 @@ class FinalizeDataset(InstanceOperatorValidator):

         )

         task_data["metadata"]["num_demos"] = instance["recipe_metadata"]["num_demos"]
+        task_data["metadata"]["demos_pool_size"] = instance["recipe_metadata"][
+            "demos_pool_size"
+        ]
         task_data["metadata"]["template"] = self.artifact_to_jsonable(
             instance["recipe_metadata"]["template"]
         )
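After this change, the per-instance metadata records the demos pool size alongside the demo count. A toy example of the resulting shape (all values illustrative):

```python
task_data = {
    "metadata": {
        "num_demos": 3,
        "demos_pool_size": 100,  # newly recorded by FinalizeDataset
        "template": "templates.classification.multi_class.default",
    }
}
```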
settings_utils.py
CHANGED
@@ -138,7 +138,7 @@ if Settings.is_uninitilized():

     settings.max_log_message_size = (int, 100000)
     settings.catalogs = None
     settings.artifactories = None
-    settings.default_recipe = "
+    settings.default_recipe = "dataset_recipe"
     settings.default_verbosity = "info"
     settings.use_eager_execution = False
     settings.remote_metrics = []

@@ -186,6 +186,7 @@ if Constants.is_uninitilized():

     constants.inference_stream = "__INFERENCE_STREAM__"
     constants.instance_stream = "__INSTANCE_STREAM__"
     constants.image_tag = "unitxt-img"
+    constants.demos_pool_field = "_demos_pool_"


 def get_settings() -> Settings:
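A hedged sketch of reading the updated default and the new constant, assuming the library is importable as `unitxt` and that `Settings`/`Constants` expose attribute access as the assignments above suggest:

```python
# assumes unitxt is installed; attribute names taken from the assignments above
from unitxt.settings_utils import get_constants, get_settings

print(get_settings().default_recipe)      # "dataset_recipe"
print(get_constants().demos_pool_field)   # "_demos_pool_"
```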
span_lableing_operators.py
CHANGED
@@ -6,19 +6,18 @@ from .operator import InstanceOperator

 class IobExtractor(InstanceOperator):
     """A class designed to extract entities from sequences of text using the Inside-Outside-Beginning (IOB) tagging convention. It identifies entities based on IOB tags and categorizes them into predefined labels such as Person, Organization, and Location.

-    labels (List[str]):
-    begin_labels (List[str]):
-    inside_labels (List[str]):
-    outside_label (str):
+    Args:
+        labels (List[str]):
+            A list of entity type labels, e.g., ["Person", "Organization", "Location"].
+        begin_labels (List[str]):
+            A list of labels indicating the beginning of an entity, e.g., ["B-PER", "B-ORG", "B-LOC"].
+        inside_labels (List[str]):
+            A list of labels indicating the continuation of an entity, e.g., ["I-PER", "I-ORG", "I-LOC"].
+        outside_label (str):
+            The label indicating tokens outside of any entity, typically "O".

     The extraction process identifies spans of text corresponding to entities and labels them according to their entity type. Each span is annotated with a start and end character offset, the entity text, and the corresponding label.

     Example of instantiation and usage:

     .. code-block:: python
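A minimal IOB decoding sketch, independent of unitxt and ignoring character offsets for brevity, mirroring the convention the docstring describes: a `B-` tag opens a span, a matching `I-` tag extends it, and anything else closes it:

```python
def decode_iob(tokens, tags):
    spans, current = [], None
    for token, tag in zip(tokens, tags):
        if tag.startswith("B-"):
            if current:
                spans.append(current)
            current = (tag[2:], [token])
        elif tag.startswith("I-") and current and current[0] == tag[2:]:
            current[1].append(token)
        else:
            if current:
                spans.append(current)
            current = None
    if current:
        spans.append(current)
    return [(label, " ".join(toks)) for label, toks in spans]

print(decode_iob(["John", "Smith", "works"], ["B-PER", "I-PER", "O"]))
# [('PER', 'John Smith')]
```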
splitters.py
CHANGED
@@ -1,11 +1,11 @@

 import itertools
 from abc import abstractmethod
 from difflib import get_close_matches
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 from .artifact import Artifact
 from .dict_utils import dict_get
-from .operator import
+from .operator import InstanceOperator, MultiStreamOperator
 from .random_utils import new_random_generator
 from .split_utils import (
     parse_random_mix_string,

@@ -14,7 +14,7 @@ from .split_utils import (

     rename_split,
     slice_streams,
 )
-from .stream import
+from .stream import MultiStream
 from .type_utils import isoftype
 from .utils import recursive_copy

@@ -118,14 +118,14 @@ class Sampler(Artifact):

     def sample(
         self,
         sample_size: int,
-        instances_pool: List[Dict[str,
-        instance: Dict[str,
-    ) -> List[Dict[str,
+        instances_pool: List[Dict[str, Any]],
+        instance: Dict[str, Any],
+    ) -> List[Dict[str, Any]]:
         pass

     def filter_source_by_instance(
-        self, instances_pool: List[Dict[str,
-    ) -> List[Dict[str,
+        self, instances_pool: List[Dict[str, Any]], instance: Dict[str, Any]
+    ) -> List[Dict[str, Any]]:
         if "input_fields" not in instance:
             raise ValueError(f"'input_fields' field is missing from '{instance}'.")
         try:

@@ -336,10 +336,11 @@ class DiverseLabelsSampler(Sampler):

         return result


-class Sample(InstanceOperatorWithMultiStreamAccess):
+class AssignDemosToInstance(InstanceOperator):
+    from_field: str
     to_field: str
     sampler: Sampler
+    skip_demoed_instances: bool = False

     def prepare(self):
         self.local_cache = None

@@ -350,40 +351,36 @@ class Sample(InstanceOperatorWithMultiStreamAccess):

         pass

     def process(
-        self, instance: Dict[str,
-    ) -> Dict[str,
-            f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {self.sampler.sample_size}."
-        )
-        sampled_instances = self.sampler.sample(
-            sample_size=sample_size, instances_pool=source_stream, instance=instance
-        )
+        self, instance: Dict[str, Any], multi_stream: MultiStream
+    ) -> Dict[str, Any]:
+        if self.skip_demoed_instances and self.to_field in instance:
+            if self.from_field in instance:
+                instance.pop(self.from_field)
+            return instance
+
+        demos_pool = instance[self.from_field]
+        sample_size = self.get_sample_size(instance)
+        source_stream = self.sampler.filter_source_by_instance(demos_pool, instance)
+        if len(source_stream) < sample_size:
+            raise ValueError(
+                f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}. Please consider increasing the demos pool, for which you may need to increase loader_limit or employ a less strict stream filtering."
+            )
+        sampled_instances = self.sampler.sample(
+            sample_size=sample_size, instances_pool=source_stream, instance=instance
+        )
+        instance[self.to_field] = recursive_copy(sampled_instances)
+        instance.pop(self.from_field)  # pop the field pointing to the demos_pool
+        return instance


-class ConstantSizeSample(
+class ConstantSizeSample(AssignDemosToInstance):
     sample_size: int

     def get_sample_size(self, instance) -> int:
         return self.sample_size


-class RandomSizeSample(
+class RandomSizeSample(AssignDemosToInstance):
     sample_sizes: List[int]

     def get_sample_size(self, instance) -> int:
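A hedged sketch of what `AssignDemosToInstance` does per instance, reduced to its essence: the demos pool travels on the instance under `from_field`, a sample size is chosen, the pool is sampled, and the pool field is popped. Names and the plain `random.sample` call are illustrative stand-ins for the `Sampler` machinery:

```python
import random

def assign_demos(instance, from_field="_demos_pool_", to_field="demos",
                 k=2, rng=random.Random(0)):
    pool = instance.pop(from_field)  # pop the field pointing to the demos pool
    if len(pool) < k:
        raise ValueError(f"pool of {len(pool)} smaller than sample_size {k}")
    instance[to_field] = rng.sample(pool, k)
    return instance

inst = {"_demos_pool_": [{"q": 1}, {"q": 2}, {"q": 3}], "question": "..."}
print(assign_demos(inst)["demos"])  # two sampled demo dicts
```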
standard.py
CHANGED
@@ -1,26 +1,35 @@
|
|
1 |
-
|
|
|
|
|
|
|
2 |
|
3 |
from .artifact import fetch_artifact
|
4 |
from .augmentors import Augmentor, NullAugmentor
|
5 |
from .card import TaskCard
|
6 |
from .collections_operators import GetLength
|
7 |
from .dataclass import Field, InternalField, NonPositionalField, OptionalField
|
|
|
8 |
from .error_utils import UnitxtError
|
9 |
from .formats import Format, SystemFormat
|
|
|
10 |
from .logging_utils import get_logger
|
11 |
-
from .operator import
|
|
|
|
|
|
|
|
|
|
|
12 |
from .operators import Set, StreamRefiner
|
13 |
-
from .recipe import Recipe
|
14 |
from .schema import FinalizeDataset
|
15 |
from .serializers import SingleTypeSerializer
|
16 |
from .settings_utils import get_constants, get_settings
|
17 |
-
from .splitters import ConstantSizeSample, RandomSizeSample, Sampler
|
18 |
from .stream import MultiStream
|
19 |
from .system_prompts import EmptySystemPrompt, SystemPrompt
|
20 |
from .task import Task
|
21 |
from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template, TemplatesList
|
22 |
from .type_utils import isoftype
|
23 |
-
from .utils import LRUCache
|
24 |
|
25 |
constants = get_constants()
|
26 |
settings = get_settings()
|
@@ -28,11 +37,205 @@ logger = get_logger()
|
|
28 |
|
29 |
|
30 |
# Used to give meaningful name to recipe steps
|
31 |
-
class CreateDemosPool(
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
-
class BaseRecipe(Recipe, SourceSequentialOperator):
|
36 |
# Base parameters
|
37 |
card: TaskCard = None
|
38 |
task: Task = None
|
@@ -59,14 +262,18 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
59 |
test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
|
60 |
|
61 |
demos_pool_size: int = None
|
|
|
62 |
num_demos: Optional[Union[int, List[int]]] = 0
|
63 |
demos_removed_from_data: bool = True
|
|
|
64 |
|
65 |
-
demos_pool_name: str = "demos_pool"
|
66 |
demos_taken_from: str = "train"
|
67 |
demos_field: str = "demos"
|
68 |
sampler: Sampler = None
|
69 |
|
|
|
|
|
|
|
70 |
augmentor: Union[Augmentor, List[Augmentor]] = OptionalField(default=None)
|
71 |
|
72 |
steps: List[StreamingOperator] = InternalField(default_factory=list)
|
@@ -101,11 +308,16 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
101 |
raise ValueError(
|
102 |
"When using demonstrations both num_demos and demos_pool_size should be assigned with positive integers."
|
103 |
)
|
104 |
-
if self.demos_pool_size < self.max_demos_size:
|
105 |
raise ValueError(
|
106 |
-
f"num_demos (got: {self.max_demos_size}) should not exceed demos_pool_size (got: {self.demos_pool_size})"
|
107 |
)
|
108 |
-
if
|
|
|
|
|
|
|
|
|
|
|
109 |
raise ValueError(
|
110 |
f"demos_pool_size should not exceed loader_limit ({self.loader_limit}), Got demos_pool_size={self.demos_pool_size}"
|
111 |
)
|
@@ -220,29 +432,21 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
220 |
self.loading,
|
221 |
self.metadata,
|
222 |
self.standardization,
|
223 |
-
self.processing,
|
224 |
]
|
225 |
|
226 |
self.inference = SequentialOperator()
|
227 |
|
228 |
-
self.inference.steps = [self.
|
229 |
|
230 |
def production_preprocess(self, task_instances):
|
231 |
ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
|
232 |
-
return list(self.
|
233 |
-
|
234 |
-
def production_demos_pool(self):
|
235 |
-
if self.use_demos:
|
236 |
-
demos_pool = self.__class__._demos_pool_cache.get(str(self), None)
|
237 |
-
if demos_pool is None:
|
238 |
-
demos_pool = list(self.inference_demos()[self.demos_pool_name])
|
239 |
-
self.__class__._demos_pool_cache[str(self)] = demos_pool
|
240 |
-
return demos_pool
|
241 |
-
return []
|
242 |
|
243 |
@property
|
244 |
def has_custom_demos_pool(self):
|
245 |
-
return self.demos_pool_size is not None and
|
|
|
|
|
246 |
|
247 |
@property
|
248 |
def use_demos(self):
@@ -251,13 +455,22 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
    def produce(self, task_instances):
        """Use the recipe in production to produce model ready query from standard task instance."""
        self.before_process_multi_stream()
-
-
-
-        if self.
-
-
-
        return list(multi_stream[constants.inference_stream])

    def reset(self):
@@ -321,15 +534,29 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
        augmentor.set_fields(self.card.task.augmentable_inputs)
        self.processing.steps.append(augmentor)

        if self.has_custom_demos_pool:
-            self.
-
-
-
-
-
                )
-            )

        if self.use_demos:
            if self.sampler is None:
@@ -346,28 +573,41 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
        if isinstance(self.num_demos, int):
            self.verbalization.steps.append(
                ConstantSizeSample(
-
                    to_field=self.demos_field,
                    sampler=self.sampler,
                    sample_size=self.num_demos,
                )
            )
            self.verbalization.steps.append(
-                Set(
            )

        elif isinstance(self.num_demos, list):
            self.verbalization.steps.append(
                RandomSizeSample(
-
                    to_field=self.demos_field,
                    sampler=self.sampler,
                    sample_sizes=self.num_demos,
                )
            )
            self.verbalization.steps.append(
                GetLength(field="demos", to_field="recipe_metadata/num_demos")
            )
        else:
            raise ValueError("num_demos must be int or List[int]")

@@ -383,9 +623,15 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
                    template=self.template, demos_field=self.demos_field
                )
            )
        else:
            self.verbalization.steps.append(
-                Set(
            )
        if isinstance(self.template, list):
            self.verbalization.steps.append(
@@ -409,15 +655,6 @@ class BaseRecipe(Recipe, SourceSequentialOperator):

        self.finalize.steps.append(FinalizeDataset(group_by=self.group_by))

-    def prepare(self):
-        if isinstance(self.template, TemplatesList):
-            self.template = self.template.items
-        self.reset_pipeline()
-
-
-class StandardRecipeWithIndexes(BaseRecipe):
-    template_card_index: int = None
-
    def prepare(self):
        assert (
            self.template_card_index is None or self.template is None
@@ -464,77 +701,41 @@ class StandardRecipeWithIndexes(BaseRecipe):
            raise ValueError(
                "No template was specified in the the 'template' or 'template_card_index' recipe arguments, and no default templates are defined the card or task"
            )

-


-class StandardRecipe(StandardRecipeWithIndexes):
-    """This class represents a standard recipe for data processing and preparation.

-
-
-

-    Args:
-        card (TaskCard):
-            TaskCard object associated with the recipe.
-        template (Template, optional):
-            Template object to be used for the recipe.
-        system_prompt (SystemPrompt, optional):
-            SystemPrompt object to be used for the recipe.
-        loader_limit (int, optional):
-            Specifies the maximum number of instances per stream to be returned from the loader (used to reduce loading time in large datasets)
-        format (SystemFormat, optional):
-            SystemFormat object to be used for the recipe.
-        metrics (List[str]):
-            list of catalog metrics to use with this recipe.
-        postprocessors (List[str]):
-            list of catalog processors to apply at post processing. (Not recommended to use from here)
-        group_by (List[Union[str, List[str]]]):
-            list of task_data or metadata keys to group global scores by.
-        train_refiner (StreamRefiner, optional):
-            Train refiner to be used in the recipe.
-        max_train_instances (int, optional):
-            Maximum training instances for the refiner.
-        validation_refiner (StreamRefiner, optional):
-            Validation refiner to be used in the recipe.
-        max_validation_instances (int, optional):
-            Maximum validation instances for the refiner.
-        test_refiner (StreamRefiner, optional):
-            Test refiner to be used in the recipe.
-        max_test_instances (int, optional):
-            Maximum test instances for the refiner.
-        demos_pool_size (int, optional):
-            Size of the demos pool.
-        num_demos (int, optional):
-            Number of demos to be used.
-        demos_pool_name (str, optional):
-            Name of the demos pool. Default is "demos_pool".
-        demos_taken_from (str, optional):
-            Specifies from where the demos are taken. Default is "train".
-        demos_field (str, optional):
-            Field name for demos. Default is "demos".
-        demos_removed_from_data (bool, optional):
-            whether to remove the demos from the source data, Default is True
-        sampler (Sampler, optional):
-            The Sampler used to select the demonstrations when num_demos > 0.
-        steps (List[StreamingOperator], optional):
-            List of StreamingOperator objects to be used in the recipe.
-        augmentor (Augmentor) :
-            Augmentor to be used to pseudo randomly augment the source text
-        instruction_card_index (int, optional):
-            Index of instruction card to be used for preparing the recipe.
-        template_card_index (int, optional):
-            Index of template card to be used for preparing the recipe.

-
-
-
-        by arranging all the steps, refiners, and renderers in a sequential manner.

-    Raises:
-        AssertionError:
-            If both template and template_card_index are specified at the same time.
-    """

    pass

+import itertools
+import json
+import sys
+from typing import Any, Dict, Generator, List, Optional, Union

from .artifact import fetch_artifact
from .augmentors import Augmentor, NullAugmentor
from .card import TaskCard
from .collections_operators import GetLength
from .dataclass import Field, InternalField, NonPositionalField, OptionalField
+from .deprecation_utils import deprecation
from .error_utils import UnitxtError
from .formats import Format, SystemFormat
+from .generator_utils import ReusableGenerator
from .logging_utils import get_logger
+from .operator import (
+    MultiStreamOperator,
+    SequentialOperator,
+    SourceSequentialOperator,
+    StreamingOperator,
+)
from .operators import Set, StreamRefiner
from .schema import FinalizeDataset
from .serializers import SingleTypeSerializer
from .settings_utils import get_constants, get_settings
+from .splitters import ConstantSizeSample, RandomSizeSample, Sampler
from .stream import MultiStream
from .system_prompts import EmptySystemPrompt, SystemPrompt
from .task import Task
from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template, TemplatesList
from .type_utils import isoftype
+from .utils import LRUCache, recursive_copy

constants = get_constants()
settings = get_settings()


# Used to give meaningful name to recipe steps
+class CreateDemosPool(MultiStreamOperator):
+    from_stream: str = None
+    demos_pool_size: int = None
+    demos_removed_from_data: bool = None
+    to_field: str = constants.demos_pool_field
+
+    # flake8: noqa: B007
+    def process(self, multi_stream: MultiStream) -> MultiStream:
+        # generate the demos_pool as a selection of demos_pool_size distinct instances
+        # (distinct by their "input_fields" field). The selection is taken from stream named from_stream.
+        # The selected instances are later treated as ordinary instances or not, depending on parameter
+        # demos_removed_from_data.
+        # The selection of instances is done from the first instances of the stream named from_stream.
+        # instances that are not distinct from previously selected demo instances, are kept aside, to be later
+        # treated like all the remaining instances of stream from_stream.
+        if self.from_stream not in multi_stream:
+            raise ValueError(
+                f"Input multi-stream is missing a stream named '{self.from_stream}' to take demo instances from for the demos_pool."
+            )
+        if (
+            self.demos_removed_from_data is not None
+            and self.demos_removed_from_data is True
+            and (self.demos_pool_size == sys.maxsize)
+        ):
+            # going to consume the whole of input stream named self.from_stream for demo instances,
+            # and not let demos instances to behave as regular instances. so self.from_stream
+            # ends here its life as an input stream that is expected to reach the end of the recipe
+            if len(multi_stream) == 1:
+                raise ValueError(
+                    f"The single input stream, '{self.from_stream}' is to be wholly consumed for generating demos, and no instance is left to use these demos."
+                )
+        from_stream = multi_stream[self.from_stream]
+        demos_pool = []
+        input_fields_of_demos_pool = []
+        not_selected_from_from_stream = []
+        for num_scanned, instance in enumerate(from_stream):
+            if "input_fields" not in instance:
+                raise ValueError(f"'input_fields' field is missing from '{instance}'.")
+            input_fields_signature = json.dumps(
+                instance["input_fields"], sort_keys=True
+            )
+            if input_fields_signature in input_fields_of_demos_pool:
+                not_selected_from_from_stream.append(instance)
+                continue
+            demos_pool.append(instance)
+            input_fields_of_demos_pool.append(input_fields_signature)
+            if len(demos_pool) >= self.demos_pool_size:
+                break
+
+        # for backward compatibility, do not throw exception here if demos pool is smaller than expected.
+        # Delay that for the event (if occurs) that Sample is not be able to sample num_demos demos.
+
+        # to avoid endless recursion in case of not demos_removed_from_data
+        demos_pool = recursive_copy(demos_pool)
+
+        set_demos_pool = Set(fields={self.to_field: demos_pool})
+        if (
+            self.demos_removed_from_data is not None
+            and self.demos_removed_from_data is False
+        ):
+            # all input instances go out. No one is "killed" because selected as demo
+            return set_demos_pool(multi_stream)
+
+        if (
+            self.demos_removed_from_data is not None
+            and self.demos_removed_from_data is True
+        ):
+            if self.demos_pool_size == sys.maxsize:
+                # consume the whole of input stream self.from_stream, just for demos, and do not
+                # take any of its instances to behave as a non-demo instance, i.e., a regular instance
+                # that consume the demos
+                out_ms = MultiStream(
+                    {
+                        stream_name: multi_stream[stream_name]
+                        for stream_name in multi_stream
+                        if stream_name != self.from_stream
+                    }
+                )
+                return set_demos_pool(out_ms)
+
+        # self.demos_removed_from_data and not consume the whole of self.from_stream just for demos
+        def from_stream_generator(
+            first_layer: list, ms: MultiStream, stream_name: str, start: int
+        ) -> Generator:
+            yield from first_layer
+            yield from itertools.islice(ms[stream_name], start, None)
+
+        new_streams = {}
+        for stream_name in multi_stream:
+            if stream_name == self.from_stream:
+                new_streams[stream_name] = ReusableGenerator(
+                    generator=from_stream_generator,
+                    gen_kwargs={
+                        "first_layer": not_selected_from_from_stream,
+                        "ms": multi_stream,
+                        "stream_name": self.from_stream,
+                        "start": num_scanned + 1,
+                    },
+                )
+            else:
+                new_streams[stream_name] = ReusableGenerator(
+                    generator=from_stream_generator,
+                    gen_kwargs={
+                        "first_layer": [],
+                        "ms": multi_stream,
+                        "stream_name": stream_name,
+                        "start": 0,
+                    },
+                )
+
+        ms = MultiStream.from_generators(new_streams)
+        return set_demos_pool(ms)
+
+
+class AddDemosPool(MultiStreamOperator):
+    demos_pool: List[Dict[str, Any]]
+    demos_pool_field_name: str = constants.demos_pool_field
+
+    def process(self, multi_stream: MultiStream) -> MultiStream:
+        set_demos_pool = Set(fields={self.demos_pool_field_name: self.demos_pool})
+        return set_demos_pool(multi_stream)
+
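The selection loop in CreateDemosPool dedupes candidate demos by a canonical JSON signature of their "input_fields". A minimal standalone sketch of that idea, independent of unitxt's stream machinery (the function name is illustrative, not part of the library):

```python
import json

def select_distinct_demos(instances, pool_size):
    """Pick the first `pool_size` instances whose "input_fields" are pairwise distinct."""
    pool, signatures, leftovers = [], set(), []
    for instance in instances:
        # sort_keys=True makes the signature independent of dict key order
        signature = json.dumps(instance["input_fields"], sort_keys=True)
        if signature in signatures:
            # duplicates are kept aside and later flow on as regular instances
            leftovers.append(instance)
            continue
        pool.append(instance)
        signatures.add(signature)
        if len(pool) >= pool_size:
            break
    return pool, leftovers
```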
+class DatasetRecipe(SourceSequentialOperator):
+    """This class represents a standard recipe for data processing and preparation.
+
+    This class can be used to prepare a recipe
+    with all necessary steps, refiners and renderers included. It allows one to set various
+    parameters and steps in a sequential manner for preparing the recipe.
+
+    Args:
+        card (TaskCard):
+            TaskCard object associated with the recipe.
+        template (Template, optional):
+            Template object to be used for the recipe.
+        system_prompt (SystemPrompt, optional):
+            SystemPrompt object to be used for the recipe.
+        loader_limit (int, optional):
+            Specifies the maximum number of instances per stream to be returned from the loader (used to reduce loading time in large datasets)
+        format (SystemFormat, optional):
+            SystemFormat object to be used for the recipe.
+        metrics (List[str]):
+            list of catalog metrics to use with this recipe.
+        postprocessors (List[str]):
+            list of catalog processors to apply at post processing. (Not recommended to use from here)
+        group_by (List[Union[str, List[str]]]):
+            list of task_data or metadata keys to group global scores by.
+        train_refiner (StreamRefiner, optional):
+            Train refiner to be used in the recipe.
+        max_train_instances (int, optional):
+            Maximum training instances for the refiner.
+        validation_refiner (StreamRefiner, optional):
+            Validation refiner to be used in the recipe.
+        max_validation_instances (int, optional):
+            Maximum validation instances for the refiner.
+        test_refiner (StreamRefiner, optional):
+            Test refiner to be used in the recipe.
+        max_test_instances (int, optional):
+            Maximum test instances for the refiner.
+        demos_pool_size (int, optional):
+            Size of the demos pool. -1 for taking the whole of stream 'demos_taken_from'.
+        demos_pool (List[Dict[str, Any]], optional):
+            a list of instances to make the demos_pool
+        num_demos (int, optional):
+            Number of demos to add to each instance, to become part of the source to be generated for this instance.
+        demos_taken_from (str, optional):
+            Specifies the stream from where the demos are taken. Default is "train".
+        demos_field (str, optional):
+            Field name for demos. Default is "demos".
+            The num_demos demos selected for an instance are stored in this field of that instance.
+        demos_pool_field_name (str, optional):
+            field name to maintain the demos_pool, until sampled from, in order to make the demos.
+            Defaults to constants.demos_pool_field.
+        demos_removed_from_data (bool, optional):
+            whether to remove the demos taken to demos_pool from the source data. Default is True.
+        sampler (Sampler, optional):
+            The Sampler used to select the demonstrations when num_demos > 0.
+        skip_demoed_instances (bool, optional):
+            whether to skip pushing demos to an instance whose demos_field is
+            already populated. Defaults to False.
+        steps (List[StreamingOperator], optional):
+            List of StreamingOperator objects to be used in the recipe.
+        augmentor (Augmentor):
+            Augmentor to be used to pseudo randomly augment the source text
+        instruction_card_index (int, optional):
+            Index of instruction card to be used for preparing the recipe.
+        template_card_index (int, optional):
+            Index of template card to be used for preparing the recipe.
+
+    Methods:
+        prepare():
+            This overridden method is used for preparing the recipe
+            by arranging all the steps, refiners, and renderers in a sequential manner.
+
+    Raises:
+        AssertionError:
+            If both template and template_card_index are specified at the same time.
+    """

    # Base parameters
    card: TaskCard = None
    task: Task = None

    test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)

    demos_pool_size: int = None
+    demos_pool: List[Dict[str, Any]] = None
    num_demos: Optional[Union[int, List[int]]] = 0
    demos_removed_from_data: bool = True
+    demos_pool_field_name: str = constants.demos_pool_field

    demos_taken_from: str = "train"
    demos_field: str = "demos"
    sampler: Sampler = None

+    # do not push demos to instances whose "demos" field is already populated
+    skip_demoed_instances: bool = False
+
    augmentor: Union[Augmentor, List[Augmentor]] = OptionalField(default=None)

    steps: List[StreamingOperator] = InternalField(default_factory=list)

            raise ValueError(
                "When using demonstrations both num_demos and demos_pool_size should be assigned with positive integers."
            )
+        if self.demos_pool_size < self.max_demos_size + 1:
            raise ValueError(
+                f"num_demos (got: {self.max_demos_size}) should not exceed demos_pool_size - 1 (got: {self.demos_pool_size}), (-1: to always allow filtering of a demo identical to the processed instance)."
            )
+        if (
+            (not self.demos_pool)
+            and (self.demos_pool_size != sys.maxsize)
+            and self.loader_limit
+            and (self.demos_pool_size > self.loader_limit)
+        ):
            raise ValueError(
                f"demos_pool_size should not exceed loader_limit ({self.loader_limit}), Got demos_pool_size={self.demos_pool_size}"
            )

            self.loading,
            self.metadata,
            self.standardization,
        ]

        self.inference = SequentialOperator()

+        self.inference.steps = [self.processing, self.verbalization, self.finalize]

    def production_preprocess(self, task_instances):
        ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
+        return list(self.metadata(ms)[constants.inference_stream])

    @property
    def has_custom_demos_pool(self):
+        return self.demos_pool_size is not None and (
+            self.demos_pool_size > 0 or self.demos_pool_size == -1
+        )

    @property
    def use_demos(self):

    def produce(self, task_instances):
        """Use the recipe in production to produce model ready query from standard task instance."""
        self.before_process_multi_stream()
+
+        ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
+        # does not hurt to set metadata
+        # task_instances are assumed to be as if passed through self.standardization
+        ms = self.metadata(ms)
+        if not self.use_demos:
+            # go with task_instances all the way, it does not need other streams:
+            ms = self.inference(ms)
+            return list(ms[constants.inference_stream])
+
+        streams = self.inference_demos()
+        # streams stopped before processing
+        # ms is ready to join, it will get the demos from streams
+        streams[constants.inference_stream] = ms[constants.inference_stream]
+        # multi_stream = MultiStream(streams)
+        multi_stream = self.inference(streams)
        return list(multi_stream[constants.inference_stream])

    def reset(self):
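Taken together, the pipeline setup and the new produce() imply a usage pattern along these lines. This is a sketch only; the card name is a placeholder and the task fields in the commented call depend on the chosen card:

```python
from unitxt.standard import DatasetRecipe

# Demos come from the first 20 train instances; 3 are sampled per instance.
recipe = DatasetRecipe(
    card="cards.wnli",        # placeholder catalog card
    template_card_index=0,
    num_demos=3,
    demos_pool_size=20,
)
dataset = recipe().to_dataset()   # full pipeline, as a SourceOperator

# In production, render single task instances into model-ready sources:
# rendered = recipe.produce([{"text_a": "...", "text_b": "...", "classes": [...]}])
```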
        augmentor.set_fields(self.card.task.augmentable_inputs)
        self.processing.steps.append(augmentor)

+        # for backward compatibility, consume the demos instances even if not pushed into demos field of the ordinary instances,
+        # in order to use the very same ordinary instances as in back releases.
+        # one example of consume but not used, and indeed skips over a problematic (json-wise) input:
+        # prepare/cards/rag/end_to_end/clapnq.py
        if self.has_custom_demos_pool:
+            if self.demos_pool:
+                self.processing.steps.append(
+                    AddDemosPool(
+                        demos_pool=self.demos_pool,
+                        demos_pool_field_name=self.demos_pool_field_name,
+                    )
+                )
+            else:
+                self.processing.steps.append(
+                    CreateDemosPool(
+                        from_stream=self.demos_taken_from,
+                        demos_pool_size=self.demos_pool_size
+                        if self.demos_pool is None
+                        else None,
+                        demos_removed_from_data=self.demos_removed_from_data,
+                        to_field=self.demos_pool_field_name,
+                    )
                )

        if self.use_demos:
            if self.sampler is None:
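This branch means a fixed, caller-supplied demos_pool takes the AddDemosPool path and bypasses CreateDemosPool entirely. A hedged sketch of that configuration (the card name and instance fields are placeholders for whatever the task defines):

```python
# Hypothetical task instances used verbatim as the demonstration pool.
fixed_demos = [
    {"question": "2+2?", "answer": "4"},
    {"question": "Capital of France?", "answer": "Paris"},
]

recipe = DatasetRecipe(
    card="cards.my_card",      # placeholder
    template_card_index=0,
    num_demos=1,
    demos_pool=fixed_demos,    # AddDemosPool path; demos_pool_size becomes len(fixed_demos)
)
```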
        if isinstance(self.num_demos, int):
            self.verbalization.steps.append(
                ConstantSizeSample(
+                    from_field=self.demos_pool_field_name,
                    to_field=self.demos_field,
                    sampler=self.sampler,
                    sample_size=self.num_demos,
+                    skip_demoed_instances=self.skip_demoed_instances,
                )
            )
            self.verbalization.steps.append(
+                Set(
+                    fields={
+                        "recipe_metadata/num_demos": self.num_demos,
+                        "recipe_metadata/demos_pool_size": self.demos_pool_size,
+                    }
+                )
            )

        elif isinstance(self.num_demos, list):
            self.verbalization.steps.append(
                RandomSizeSample(
+                    from_field=self.demos_pool_field_name,
                    to_field=self.demos_field,
                    sampler=self.sampler,
                    sample_sizes=self.num_demos,
+                    skip_demoed_instances=self.skip_demoed_instances,
                )
            )
            self.verbalization.steps.append(
                GetLength(field="demos", to_field="recipe_metadata/num_demos")
            )
+            self.verbalization.steps.append(
+                Set(
+                    fields={"recipe_metadata/demos_pool_size": self.demos_pool_size}
+                )
+            )
+
        else:
            raise ValueError("num_demos must be int or List[int]")
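When num_demos is a list, RandomSizeSample draws a different demo count per instance, which is why the actual count is recorded per instance via GetLength rather than as a constant. A sketch of that configuration (card name is a placeholder):

```python
recipe = DatasetRecipe(
    card="cards.my_card",     # placeholder
    template_card_index=0,
    num_demos=[0, 1, 3],      # each instance gets 0, 1, or 3 demos, chosen at random
    demos_pool_size=20,
)
# Each rendered instance then carries its own "recipe_metadata/num_demos" value.
```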

                    template=self.template, demos_field=self.demos_field
                )
            )
+
        else:
            self.verbalization.steps.append(
+                Set(
+                    fields={
+                        "recipe_metadata/num_demos": 0,
+                        "recipe_metadata/demos_pool_size": 0,
+                    }
+                )
            )
        if isinstance(self.template, list):
            self.verbalization.steps.append(

        self.finalize.steps.append(FinalizeDataset(group_by=self.group_by))

    def prepare(self):
        assert (
            self.template_card_index is None or self.template is None

            raise ValueError(
                "No template was specified in the the 'template' or 'template_card_index' recipe arguments, and no default templates are defined the card or task"
            )
+        if self.use_demos:
+            assert (
+                self.demos_pool is not None
+                and isoftype(self.demos_pool, List[Dict[str, Any]])
+            ) != (
+                self.demos_taken_from is not None
+                and self.demos_pool_size is not None
+                and self.demos_removed_from_data is not None
+            ), (
+                "The demos_pool must be specified by exactly one of two ways: explicitly, as a list of instances coming through parameter "
+                + "'demos_pool', or via parameters 'demos_taken_from', 'demos_pool_size', and 'demos_removed_from_data', "
+                + "that together direct its production."
+            )

+        # now set self.demos_pool_size for the checks done by verify
+        if self.demos_pool:
+            self.demos_pool_size = len(self.demos_pool)
+        if self.demos_pool_size is not None and self.demos_pool_size == -1:
+            self.demos_pool_size = sys.maxsize

+        if isinstance(self.template, TemplatesList):
+            self.template = self.template.items
+        self.reset_pipeline()


+@deprecation(version="2.0.0", alternative=DatasetRecipe)
+class BaseRecipe(DatasetRecipe):
+    pass


+@deprecation(version="2.0.0", alternative=DatasetRecipe)
+class StandardRecipeWithIndexes(DatasetRecipe):
+    pass


+@deprecation(version="2.0.0", alternative=DatasetRecipe)
+class StandardRecipe(DatasetRecipe):
    pass
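The three legacy names are now thin subclasses of DatasetRecipe decorated with @deprecation, so existing imports keep working while pointing users at the new class. A small check of what that implies (assuming, as the diff suggests, that the decorator returns the class itself):

```python
from unitxt.standard import (
    BaseRecipe,
    DatasetRecipe,
    StandardRecipe,
    StandardRecipeWithIndexes,
)

# The aliases add no behavior of their own; instantiating them should emit a
# deprecation notice (exact mechanism depends on the @deprecation decorator).
for legacy in (BaseRecipe, StandardRecipeWithIndexes, StandardRecipe):
    assert issubclass(legacy, DatasetRecipe)
```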
struct_data_operators.py
CHANGED
@@ -679,7 +679,7 @@ class LoadJson(FieldOperator):
        except json.JSONDecodeError:
            return self.failure_value
        else:
-            return json.loads(value)


class DumpJson(FieldOperator):

        except json.JSONDecodeError:
            return self.failure_value
        else:
+            return json.loads(value, strict=False)


class DumpJson(FieldOperator):
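The strict=False switch makes json.loads accept unescaped control characters (literal tabs or newlines) inside string values instead of raising. A quick illustration with the standard library alone:

```python
import json

raw = '{"text": "line one\nline two"}'  # literal newline inside the string value

try:
    json.loads(raw)                     # default strict=True rejects it
except json.JSONDecodeError as error:
    print("strict parse fails:", error)

print(json.loads(raw, strict=False))    # {'text': 'line one\nline two'}
```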
task.py
CHANGED
@@ -40,25 +40,22 @@ def parse_string_types_instead_of_actual_objects(obj):
class Task(InstanceOperator, ArtifactFetcherMixin):
    """Task packs the different instance fields into dictionaries by their roles in the task.

-
        input_fields (Union[Dict[str, str], List[str]]):
-
-
-
        reference_fields (Union[Dict[str, str], List[str]]):
-
-
-
-
-
        prediction_type (Optional[str]):
-
-
-
        defaults (Optional[Dict[str, Any]]):
-
-
-

    The output instance contains three fields:
    1. "input_fields" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'.
@@ -119,7 +116,7 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
                self.prediction_type
            )

-    def
        if hasattr(self, "inputs") and self.inputs is not None:
            depr_message = (
                "The 'inputs' field is deprecated. Please use 'input_fields' instead."
@@ -130,6 +127,9 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
            depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

        if self.input_fields is None:
            raise UnitxtError(
                "Missing attribute in task: 'input_fields' not set.",
@@ -155,7 +155,11 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
                f"will raise an exception.",
                Documentation.ADDING_TASK,
            )
-            data
            if io_type == "input_fields":
                self.input_fields = data
            else:
@@ -290,6 +294,9 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
            "media": instance.get("media", {}),
            "recipe_metadata": instance.get("recipe_metadata", {}),
        }

        if stream_name == constants.inference_stream:
            return result

class Task(InstanceOperator, ArtifactFetcherMixin):
    """Task packs the different instance fields into dictionaries by their roles in the task.

+    Args:
        input_fields (Union[Dict[str, str], List[str]]):
+            Dictionary with string names of instance input fields and types of respective values.
+            In case a list is passed, each type will be assumed to be Any.
        reference_fields (Union[Dict[str, str], List[str]]):
+            Dictionary with string names of instance output fields and types of respective values.
+            In case a list is passed, each type will be assumed to be Any.
+        metrics (List[str]):
+            List of names of metrics to be used in the task.
        prediction_type (Optional[str]):
+            Need to be consistent with all used metrics. Defaults to None, which means that it will
+            be set to Any.
        defaults (Optional[Dict[str, Any]]):
+            An optional dictionary with default values for chosen input/output keys. Needs to be
+            consistent with names and types provided in 'input_fields' and/or 'output_fields' arguments.
+            Will not overwrite values if already provided in a given instance.

    The output instance contains three fields:
    1. "input_fields" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'.

                self.prediction_type
            )

+    def task_deprecations(self):
        if hasattr(self, "inputs") and self.inputs is not None:
            depr_message = (
                "The 'inputs' field is deprecated. Please use 'input_fields' instead."

            depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

+    def verify(self):
+        self.task_deprecations()
+
        if self.input_fields is None:
            raise UnitxtError(
                "Missing attribute in task: 'input_fields' not set.",

                f"will raise an exception.",
                Documentation.ADDING_TASK,
            )
+            if isinstance(data, dict):
+                data = parse_type_dict(to_type_dict(data))
+            else:
+                data = {key: Any for key in data}
+
            if io_type == "input_fields":
                self.input_fields = data
            else:

            "media": instance.get("media", {}),
            "recipe_metadata": instance.get("recipe_metadata", {}),
        }
+        if "demos" in instance:
+            # for the case of recipe.skip_demoed_instances
+            result["demos"] = instance["demos"]

        if stream_name == constants.inference_stream:
            return result
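Per the new normalization, a Task accepts either typed dicts or plain lists for its field specifications; a list is promoted to a dict with every type set to Any. A hedged sketch of the two equivalent declarations (the metric name is from the catalog, the field names are illustrative):

```python
from unitxt.task import Task

# List form: field types default to Any after normalization.
task_from_list = Task(
    input_fields=["question"],
    reference_fields=["answer"],
    prediction_type="str",
    metrics=["metrics.accuracy"],
)

# Dict form: types are stated explicitly and parsed from strings.
task_from_dict = Task(
    input_fields={"question": "str"},
    reference_fields={"answer": "str"},
    prediction_type="str",
    metrics=["metrics.accuracy"],
)
```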
templates.py
CHANGED
@@ -307,26 +307,27 @@ class PairwiseChoiceTemplate(InputOutputTemplate):
    The answer field value should be of type Literal["choice_a", "choice_b", "tie"]

    Args:
-        choice_a_field (str):
-
-        choice_b_field (str):
-
-        answer_field (str):
-
-
-        choice_a_label (str):
-
-        choice_b_label (str):
-
-        choice_tie_label (str):
-
-        shuffle (bool):

    shuffle: 50% of the time:
    1. The values of choice_a_field and choice_b_field will be swapped.
    2. If the values of answer_field is choice_a_label, set it to choice_b_label.
-
-

    """

@@ -636,21 +637,22 @@ class MultipleChoiceTemplate(InputFormatTemplate):
class YesNoTemplate(InputFormatTemplate):
    """A template for generating binary Yes/No questions asking whether an input text is of a specific class.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
    """

    input_format: str = None

    The answer field value should be of type Literal["choice_a", "choice_b", "tie"]

    Args:
+        choice_a_field (str):
+            The field which contains choice_a value
+        choice_b_field (str):
+            The field which contains choice_b value
+        answer_field (str):
+            The field which contains the answer value.
+            Should be of type Literal["choice_1", "choice_2", "tie"]
+        choice_a_label (str):
+            The label of choice A answer as it is verbalized in the template.
+        choice_b_label (str):
+            The label of choice B answer as it is verbalized in the template.
+        choice_tie_label (str):
+            The label of a tie answer as it should be verbalized in the template.
+        shuffle (bool):
+            whether to shuffle the choices or not. This is done to take into account position bias.

    shuffle: 50% of the time:
    1. The values of choice_a_field and choice_b_field will be swapped.
    2. If the values of answer_field is choice_a_label, set it to choice_b_label.
+        Else if the values of answer_field is choice_b_label, set it to choice_a_label.
+        Else if the value of answer_field is choice_tie_label, do nothing.

    """

class YesNoTemplate(InputFormatTemplate):
    """A template for generating binary Yes/No questions asking whether an input text is of a specific class.

+    Args:
+        input_format:
+            Defines the format of the question.
+        class_field:
+            Defines the field that contains the name of the class that this template
+            asks of.
+        label_field:
+            Defines the field which contains the true label of the input text. If a gold label is equal to the
+            value in class_name, then the correct output is self.yes_answer (by default, "Yes").
+            Otherwise the correct output is self.no_answer (by default, "No").
+        yes_answer:
+            The output value for when the gold label equals self.class_name.
+            Defaults to "Yes".
+        no_answer:
+            The output value for when the gold label differs from self.class_name.
+            Defaults to "No".
    """

    input_format: str = None
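Putting the documented fields together, a YesNoTemplate might be configured as below. This is a sketch based on the Args above; the placeholder names in input_format must match fields the task actually provides:

```python
from unitxt.templates import YesNoTemplate

template = YesNoTemplate(
    input_format="Is the following text about {class}?\nText: {text}",
    class_field="class",    # field holding the class name being asked about
    label_field="labels",   # field holding the gold label(s) to compare against
    yes_answer="Yes",       # default shown for clarity
    no_answer="No",
)
```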
type_utils.py
CHANGED
@@ -552,6 +552,9 @@ def strtype(typing_type) -> str:
    - The function checks the `__origin__` attribute to determine the base type and formats
      the type arguments accordingly.
    """
    if not is_type(typing_type):
        raise UnsupportedTypeError(typing_type)

    - The function checks the `__origin__` attribute to determine the base type and formats
      the type arguments accordingly.
    """
+    if isinstance(typing_type, str):
+        return typing_type
+
    if not is_type(typing_type):
        raise UnsupportedTypeError(typing_type)
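With the new guard, strtype becomes a pass-through for inputs that are already string type names rather than raising UnsupportedTypeError. Expected behavior under this change (a sketch):

```python
from typing import List
from unitxt.type_utils import strtype

assert strtype(List[str]) == "List[str]"
# New: a string type name passes through unchanged instead of raising.
assert strtype("List[str]") == "List[str]"
```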
version.py
CHANGED
@@ -1 +1 @@
-version = "1.16.
+version = "1.16.1"