Elron committed on
Commit 24df49f · verified · 1 Parent(s): 88c61d3

Upload folder using huggingface_hub

README.md CHANGED
@@ -30,7 +30,7 @@ In the dynamic landscape of generative NLP, traditional text processing pipeline
30
  ![license](https://img.shields.io/github/license/ibm/unitxt)
31
  ![python](https://img.shields.io/badge/python-3.8%20|%203.9-blue)
32
  ![tests](https://img.shields.io/github/actions/workflow/status/ibm/unitxt/library_tests.yml?branch=main&label=tests)
33
- [![codecov](https://codecov.io/gh/IBM/unitxt/branch/main/graph/badge.svg?token=mlrWq9cwz3)](https://codecov.io/gh/IBM/unitxt)
34
  ![Read the Docs](https://img.shields.io/readthedocs/unitxt)
35
  [![downloads](https://static.pepy.tech/personalized-badge/unitxt?period=total&units=international_system&left_color=grey&right_color=green&left_text=downloads)](https://pepy.tech/project/unitxt)
36
 
 
30
  ![license](https://img.shields.io/github/license/ibm/unitxt)
31
  ![python](https://img.shields.io/badge/python-3.8%20|%203.9-blue)
32
  ![tests](https://img.shields.io/github/actions/workflow/status/ibm/unitxt/library_tests.yml?branch=main&label=tests)
33
+ [![Coverage Status](https://coveralls.io/repos/github/IBM/unitxt/badge.svg)](https://coveralls.io/github/IBM/unitxt)
34
  ![Read the Docs](https://img.shields.io/readthedocs/unitxt)
35
  [![downloads](https://static.pepy.tech/personalized-badge/unitxt?period=total&units=international_system&left_color=grey&right_color=green&left_text=downloads)](https://pepy.tech/project/unitxt)
36
 
api.py CHANGED
@@ -18,7 +18,7 @@ from .metric_utils import EvaluationResults, _compute, _inference_post_process
18
  from .operator import SourceOperator
19
  from .schema import UNITXT_DATASET_SCHEMA, loads_instance
20
  from .settings_utils import get_constants, get_settings
21
- from .standard import StandardRecipe
22
  from .task import Task
23
 
24
  logger = get_logger()
@@ -35,7 +35,7 @@ def load(source: Union[SourceOperator, str]):
35
  return source().to_dataset()
36
 
37
 
38
- def _get_recipe_from_query(dataset_query: str) -> StandardRecipe:
39
  dataset_query = dataset_query.replace("sys_prompt", "instruction")
40
  try:
41
  dataset_stream, _ = fetch_artifact(dataset_query)
@@ -44,14 +44,14 @@ def _get_recipe_from_query(dataset_query: str) -> StandardRecipe:
44
  return dataset_stream
45
 
46
 
47
- def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> StandardRecipe:
48
- recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
49
  for param in dataset_params.keys():
50
  assert param in recipe_attributes, (
51
- f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
52
  f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
53
  )
54
- return StandardRecipe(**dataset_params)
55
 
56
 
57
  def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
@@ -76,8 +76,8 @@ def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None)
76
  )
77
 
78
 
79
- def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> StandardRecipe:
80
- if isinstance(dataset_query, StandardRecipe):
81
  return dataset_query
82
 
83
  _verify_dataset_args(dataset_query, kwargs)
@@ -230,7 +230,7 @@ def infer(
230
  return_data: bool = False,
231
  return_log_probs: bool = False,
232
  return_meta_data: bool = False,
233
- previous_messages: Optional[list[dict[str, str]]] = None,
234
  **kwargs,
235
  ):
236
  dataset = produce(instance_or_instances, dataset_query, **kwargs)
@@ -283,7 +283,7 @@ def select(
283
  engine: OptionSelectingByLogProbsInferenceEngine,
284
  dataset_query: Optional[str] = None,
285
  return_data: bool = False,
286
- previous_messages: Optional[list[dict[str, str]]] = None,
287
  **kwargs,
288
  ):
289
  dataset = produce(instance_or_instances, dataset_query, **kwargs)
 
18
  from .operator import SourceOperator
19
  from .schema import UNITXT_DATASET_SCHEMA, loads_instance
20
  from .settings_utils import get_constants, get_settings
21
+ from .standard import DatasetRecipe
22
  from .task import Task
23
 
24
  logger = get_logger()
 
35
  return source().to_dataset()
36
 
37
 
38
+ def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
39
  dataset_query = dataset_query.replace("sys_prompt", "instruction")
40
  try:
41
  dataset_stream, _ = fetch_artifact(dataset_query)
 
44
  return dataset_stream
45
 
46
 
47
+ def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
48
+ recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
49
  for param in dataset_params.keys():
50
  assert param in recipe_attributes, (
51
+ f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
52
  f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
53
  )
54
+ return DatasetRecipe(**dataset_params)
55
 
56
 
57
  def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
 
76
  )
77
 
78
 
79
+ def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
80
+ if isinstance(dataset_query, DatasetRecipe):
81
  return dataset_query
82
 
83
  _verify_dataset_args(dataset_query, kwargs)
 
230
  return_data: bool = False,
231
  return_log_probs: bool = False,
232
  return_meta_data: bool = False,
233
+ previous_messages: Optional[List[Dict[str, str]]] = None,
234
  **kwargs,
235
  ):
236
  dataset = produce(instance_or_instances, dataset_query, **kwargs)
 
283
  engine: OptionSelectingByLogProbsInferenceEngine,
284
  dataset_query: Optional[str] = None,
285
  return_data: bool = False,
286
+ previous_messages: Optional[List[Dict[str, str]]] = None,
287
  **kwargs,
288
  ):
289
  dataset = produce(instance_or_instances, dataset_query, **kwargs)
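A note on the typing changes above: switching from built-in generics (`list[dict[str, str]]`) to `typing.List`/`typing.Dict` keeps these signatures importable on Python 3.8, which the README badge still advertises. A minimal, self-contained sketch of the compatible style (the function name is illustrative, not part of this commit):

```python
# Minimal sketch: typing-module generics stay valid on Python 3.8, whereas
# built-in generics such as `list[dict[str, str]]` require Python 3.9+.
from typing import Dict, List, Optional


def infer_stub(previous_messages: Optional[List[Dict[str, str]]] = None) -> List[str]:
    """Hypothetical stand-in for api.infer; only the annotation style matters here."""
    previous_messages = previous_messages or []
    return [message["content"] for message in previous_messages]


print(infer_stub([{"role": "user", "content": "hello"}]))  # ['hello']
```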
benchmark.py CHANGED
@@ -5,7 +5,7 @@ from .dataclass import NonPositionalField
5
  from .formats import Format
6
  from .fusion import FixedFusion, WeightedFusion
7
  from .operator import SourceOperator
8
- from .standard import StandardRecipe
9
  from .stream import MultiStream
10
  from .system_prompts import SystemPrompt
11
 
@@ -22,7 +22,7 @@ class BaseBenchmark(SourceOperator):
22
 
23
 
24
  class Benchmark(BaseBenchmark):
25
- subsets: Dict[str, Union[StandardRecipe, BaseBenchmark]]
26
 
27
  max_total_samples: int = None
28
  max_samples_per_subset: int = None
 
5
  from .formats import Format
6
  from .fusion import FixedFusion, WeightedFusion
7
  from .operator import SourceOperator
8
+ from .standard import DatasetRecipe
9
  from .stream import MultiStream
10
  from .system_prompts import SystemPrompt
11
 
 
22
 
23
 
24
  class Benchmark(BaseBenchmark):
25
+ subsets: Dict[str, Union[DatasetRecipe, BaseBenchmark]]
26
 
27
  max_total_samples: int = None
28
  max_samples_per_subset: int = None
blocks.py CHANGED
@@ -18,7 +18,7 @@ from .operators import (
18
  )
19
  from .processors import ToString, ToStringStripped
20
  from .recipe import SequentialRecipe
21
- from .splitters import RandomSampler, Sample, SliceSplit, SplitRandomMix
22
  from .stream import MultiStream
23
  from .struct_data_operators import (
24
  ConstructTableFromRowsCols,
 
18
  )
19
  from .processors import ToString, ToStringStripped
20
  from .recipe import SequentialRecipe
21
+ from .splitters import AssignDemosToInstance, RandomSampler, SliceSplit, SplitRandomMix
22
  from .stream import MultiStream
23
  from .struct_data_operators import (
24
  ConstructTableFromRowsCols,
card.py CHANGED
@@ -12,16 +12,17 @@ from .templates import Template, TemplatesDict, TemplatesList
12
  class TaskCard(Artifact):
13
  """TaskCard delineates the phases in transforming the source dataset into model input, and specifies the metrics for evaluation of model output.
14
 
15
- Attributes:
16
- loader: specifies the source address and the loading operator that can access that source and transform it into a unitxt multistream.
17
-
18
- preprocess_steps: list of unitxt operators to process the data source into model input.
19
-
20
- task: specifies the fields (of the already (pre)processed instance) making the inputs, the fields making the outputs, and the metrics to be used for evaluating the model output.
21
-
22
- templates: format strings to be applied on the input fields (specified by the task) and the output fields. The template also carries the instructions and the list of postprocessing steps, to be applied to the model output.
23
-
24
- default_template: a default template for tasks with very specific task dataset specific template
 
25
  """
26
 
27
  loader: Loader
 
12
  class TaskCard(Artifact):
13
  """TaskCard delineates the phases in transforming the source dataset into model input, and specifies the metrics for evaluation of model output.
14
 
15
+ Args:
16
+ loader:
17
+ specifies the source address and the loading operator that can access that source and transform it into a unitxt multistream.
18
+ preprocess_steps:
19
+ list of unitxt operators to process the data source into model input.
20
+ task:
21
+ specifies the fields (of the already (pre)processed instance) making the inputs, the fields making the outputs, and the metrics to be used for evaluating the model output.
22
+ templates:
23
+ format strings to be applied on the input fields (specified by the task) and the output fields. The template also carries the instructions and the list of postprocessing steps, to be applied to the model output.
24
+ default_template:
25
+ a default template for tasks with very specific task dataset specific template
26
  """
27
 
28
  loader: Loader
dataset_utils.py CHANGED
@@ -5,7 +5,7 @@ from .logging_utils import get_logger
5
  from .parsing_utils import parse_key_equals_value_string_to_dict
6
  from .register import _reset_env_local_catalogs, register_all_artifacts
7
  from .settings_utils import get_settings
8
- from .standard import BaseRecipe
9
 
10
  logger = get_logger()
11
  settings = get_settings()
@@ -24,7 +24,7 @@ def parse(query: str):
24
 
25
 
26
  def get_dataset_artifact(dataset):
27
- if isinstance(dataset, BaseRecipe):
28
  return dataset
29
  assert isinstance(
30
  dataset, str
 
5
  from .parsing_utils import parse_key_equals_value_string_to_dict
6
  from .register import _reset_env_local_catalogs, register_all_artifacts
7
  from .settings_utils import get_settings
8
+ from .standard import DatasetRecipe
9
 
10
  logger = get_logger()
11
  settings = get_settings()
 
24
 
25
 
26
  def get_dataset_artifact(dataset):
27
+ if isinstance(dataset, DatasetRecipe):
28
  return dataset
29
  assert isinstance(
30
  dataset, str
deprecation_utils.py CHANGED
@@ -18,19 +18,24 @@ def compare_versions(version1, version2):
18
  """Compare two semantic versioning strings and determine their relationship.
19
 
20
  Parameters:
21
- - version1 (str): The first version string to compare.
22
- - version2 (str): The second version string to compare.
 
 
23
 
24
  Returns:
25
- - int: -1 if version1 < version2, 1 if version1 > version2, 0 if equal.
26
 
27
  Example:
28
- >>> compare_versions("1.2.0", "1.2.3")
29
- -1
30
- >>> compare_versions("1.3.0", "1.2.8")
31
- 1
32
- >>> compare_versions("1.0.0", "1.0.0")
33
- 0
 
 
 
34
  """
35
  parts1 = [int(part) for part in version1.split(".")]
36
  parts2 = [int(part) for part in version2.split(".")]
 
18
  """Compare two semantic versioning strings and determine their relationship.
19
 
20
  Parameters:
21
+ version1 (str):
22
+ The first version string to compare.
23
+ version2 (str):
24
+ The second version string to compare.
25
 
26
  Returns:
27
+ int: -1 if version1 < version2, 1 if version1 > version2, 0 if equal.
28
 
29
  Example:
30
+ .. code-block:: text
31
+
32
+ >>> compare_versions("1.2.0", "1.2.3")
33
+ -1
34
+ >>> compare_versions("1.3.0", "1.2.8")
35
+ 1
36
+ >>> compare_versions("1.0.0", "1.0.0")
37
+ 0
38
+
39
  """
40
  parts1 = [int(part) for part in version1.split(".")]
41
  parts2 = [int(part) for part in version2.split(".")]
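For reference, a self-contained sketch of `compare_versions` consistent with the docstring contract and the parsing lines visible in this hunk; the zero-padding of unequal-length versions is an assumption, since that branch falls outside the diff:

```python
def compare_versions(version1: str, version2: str) -> int:
    """Return -1 if version1 < version2, 1 if greater, 0 if equal."""
    parts1 = [int(part) for part in version1.split(".")]
    parts2 = [int(part) for part in version2.split(".")]
    # Assumption: shorter versions are padded with zeros, so "1.2" == "1.2.0".
    length = max(len(parts1), len(parts2))
    parts1 += [0] * (length - len(parts1))
    parts2 += [0] * (length - len(parts2))
    if parts1 < parts2:
        return -1
    if parts1 > parts2:
        return 1
    return 0


assert compare_versions("1.2.0", "1.2.3") == -1
assert compare_versions("1.3.0", "1.2.8") == 1
assert compare_versions("1.0.0", "1.0.0") == 0
```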
dialog_operators.py CHANGED
@@ -27,12 +27,17 @@ class SerializeDialog(InstanceFieldOperator):
27
  of system responses and can operate on a per-turn basis or aggregate the entire
28
  dialog.
29
 
30
- Attributes:
31
- field (str): The field in the input data that contains the dialog.
32
- to_field (Optional[str]): The field in the output data where the serialized dialog will be stored.
33
- last_user_turn_to_field (Optional[str]): Field to store the last user turn.
34
- last_system_turn_to_field (Optional[str]): Field to store the last system turn.
35
- context_field (Optional[str]): Field that contains additional context to be prepended to the dialog.
 
 
 
 
 
36
  """
37
 
38
  format: SystemFormat = None
@@ -100,12 +105,17 @@ class SerializeOpenAiFormatDialog(SerializeDialog):
100
  of system responses and can operate on a per-turn basis or aggregate the entire
101
  dialog.
102
 
103
- Attributes:
104
- field (str): The field in the input data that contains the dialog.
105
- to_field (Optional[str]): The field in the output data where the serialized dialog will be stored.
106
- last_user_turn_to_field (Optional[str]): Field to store the last user turn.
107
- last_system_turn_to_field (Optional[str]): Field to store the last system turn.
108
- context_field (Optional[str]): Field that contains additional context to be prepended to the dialog.
 
 
 
 
 
109
  """
110
 
111
  is_last_turn_user_only: bool = True
 
27
  of system responses and can operate on a per-turn basis or aggregate the entire
28
  dialog.
29
 
30
+ Args:
31
+ field (str):
32
+ The field in the input data that contains the dialog.
33
+ to_field (Optional[str]):
34
+ The field in the output data where the serialized dialog will be stored.
35
+ last_user_turn_to_field (Optional[str]):
36
+ Field to store the last user turn.
37
+ last_system_turn_to_field (Optional[str]):
38
+ Field to store the last system turn.
39
+ context_field (Optional[str]):
40
+ Field that contains additional context to be prepended to the dialog.
41
  """
42
 
43
  format: SystemFormat = None
 
105
  of system responses and can operate on a per-turn basis or aggregate the entire
106
  dialog.
107
 
108
+ Args:
109
+ field (str):
110
+ The field in the input data that contains the dialog.
111
+ to_field (Optional[str]):
112
+ The field in the output data where the serialized dialog will be stored.
113
+ last_user_turn_to_field (Optional[str]):
114
+ Field to store the last user turn.
115
+ last_system_turn_to_field (Optional[str]):
116
+ Field to store the last system turn.
117
+ context_field (Optional[str]):
118
+ Field that contains additional context to be prepended to the dialog.
119
  """
120
 
121
  is_last_turn_user_only: bool = True
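A rough illustration of the serialization these docstrings describe: turning a list of user/system turns (plus optional context) into a single string. The exact turn labels and separators unitxt uses are not shown in this diff, so the formatting below is assumed:

```python
from typing import Dict, List

# Illustrative dialog; the real operator reads it from a configurable `field`.
dialog: List[Dict[str, str]] = [
    {"user": "Hello, how are you?", "system": "Doing well, thanks."},
    {"user": "What is the capital of France?", "system": "Paris."},
]


def serialize_dialog(turns: List[Dict[str, str]], context: str = "") -> str:
    """Concatenate user/system turns into one prompt-style string (format assumed)."""
    lines = [context] if context else []
    for turn in turns:
        lines.append(f"user: {turn['user']}")
        lines.append(f"agent: {turn['system']}")
    return "\n".join(lines)


print(serialize_dialog(dialog, context="Conversation about geography:"))
```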
error_utils.py CHANGED
@@ -27,10 +27,12 @@ def additional_info(path: str) -> str:
27
  class UnitxtError(Exception):
28
  """Exception raised for Unitxt errors.
29
 
30
- Attributes:
31
- message : str -- explanation of the error
32
- additional_info_id : Optional[str] -- relative path to additional documentation on web
33
- If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
 
 
34
 
35
  """
36
 
@@ -43,10 +45,12 @@ class UnitxtError(Exception):
43
  class UnitxtWarning:
44
  """Object to format warning message to log.
45
 
46
- Attributes:
47
- message -- explanation of the warning
48
- additional_info_id : Optional[str] -- relative path to additional documentation on web
49
- If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
 
 
50
  """
51
 
52
  def __init__(self, message: str, additional_info_id: Optional[str] = None):
 
27
  class UnitxtError(Exception):
28
  """Exception raised for Unitxt errors.
29
 
30
+ Args:
31
+ message (str):
32
+ explanation of the error
33
+ additional_info_id (Optional[str]):
34
+ relative path to additional documentation on web
35
+ If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
36
 
37
  """
38
 
 
45
  class UnitxtWarning:
46
  """Object to format warning message to log.
47
 
48
+ Args:
49
+ message (str):
50
+ explanation of the warning
51
+ additional_info_id (Optional[str]):
52
+ relative path to additional documentation on web
53
+ If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
54
  """
55
 
56
  def __init__(self, message: str, additional_info_id: Optional[str] = None):
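A minimal sketch of the pattern these docstrings describe: an exception that appends a relative documentation path to its message. The class name and base URL below are illustrative, not taken from the commit:

```python
from typing import Optional

DOCS_BASE_URL = "https://www.unitxt.ai/"  # illustrative constant, not from the diff


class DocumentedError(Exception):
    """Sketch of an error that points readers at additional documentation."""

    def __init__(self, message: str, additional_info_id: Optional[str] = None):
        if additional_info_id is not None:
            message += f"\nFor more information see: {DOCS_BASE_URL}{additional_info_id}"
        super().__init__(message)


try:
    raise DocumentedError("Metric not found", "docs/adding_metric.html")
except DocumentedError as error:
    print(error)
```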
image_operators.py CHANGED
@@ -216,13 +216,15 @@ class GridLines(ImageAugmentor):
216
  class PixelNoise(ImageAugmentor):
217
  """A class that overlays a mask of randomly colored nxn squares across an image based on a specified noise rate.
218
 
219
- Attributes:
220
- square_size (int): Size of each square in pixels.
221
-
222
- noise_rate (float): Proportion of the image that should be affected by noise (0 to 1).
 
223
 
224
  Methods:
225
- process_image(image): Adds the random square mask to the provided image and returns the modified image.
 
226
  """
227
 
228
  square_size: int = 1
 
216
  class PixelNoise(ImageAugmentor):
217
  """A class that overlays a mask of randomly colored nxn squares across an image based on a specified noise rate.
218
 
219
+ Args:
220
+ square_size (int):
221
+ Size of each square in pixels.
222
+ noise_rate (float):
223
+ Proportion of the image that should be affected by noise (0 to 1).
224
 
225
  Methods:
226
+ process_image(image):
227
+ Adds the random square mask to the provided image and returns the modified image.
228
  """
229
 
230
  square_size: int = 1
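A NumPy-only sketch of the behaviour the `PixelNoise` docstring describes (randomly coloured `square_size` x `square_size` patches covering roughly `noise_rate` of the image); it mirrors the description, not the library's actual implementation:

```python
import numpy as np


def add_pixel_noise(image: np.ndarray, square_size: int = 1,
                    noise_rate: float = 0.3, seed: int = 0) -> np.ndarray:
    """Overlay random-colour squares on approximately noise_rate of the image."""
    rng = np.random.default_rng(seed)
    noisy = image.copy()
    height, width, channels = noisy.shape
    for y in range(0, height, square_size):
        for x in range(0, width, square_size):
            if rng.random() < noise_rate:
                colour = rng.integers(0, 256, size=channels, dtype=noisy.dtype)
                noisy[y:y + square_size, x:x + square_size] = colour
    return noisy


image = np.zeros((8, 8, 3), dtype=np.uint8)
print(add_pixel_noise(image).shape)  # (8, 8, 3)
```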
inference.py CHANGED
@@ -9,6 +9,7 @@ import sys
9
  import time
10
  import uuid
11
  from collections import Counter
 
12
  from typing import (
13
  Any,
14
  Dict,
@@ -63,8 +64,8 @@ class StandardAPIParamsMixin(Artifact):
63
  n: Optional[int] = None
64
  parallel_tool_calls: Optional[bool] = None
65
  service_tier: Optional[Literal["auto", "default"]] = None
66
- credentials: Optional[dict[str, str]] = {}
67
- extra_headers: Optional[dict[str, str]] = None
68
 
69
 
70
  def get_model_and_label_id(model_name, label):
@@ -1171,8 +1172,8 @@ class OptionSelectingByLogProbsInferenceEngine:
1171
  for option in instance["task_data"]["options"]
1172
  ]
1173
 
1174
- dataset_with_options_logprobs: list[
1175
- list[dict[str, float | str]]
1176
  ] = self.get_options_log_probs(dataset_with_options)
1177
 
1178
  dataset_iterator = iter(dataset_with_options_logprobs)
@@ -1469,6 +1470,13 @@ class OpenAiInferenceEngineParams(Artifact):
1469
  service_tier: Optional[Literal["auto", "default"]] = None
1470
 
1471
 
 
 
 
 
 
 
 
1472
  class OpenAiInferenceEngine(
1473
  InferenceEngine,
1474
  LogProbInferenceEngine,
@@ -1485,6 +1493,7 @@ class OpenAiInferenceEngine(
1485
  base_url: Optional[str] = None
1486
  default_headers: Dict[str, str] = {}
1487
  credentials: CredentialsOpenAi = {}
 
1488
 
1489
  def get_engine_id(self) -> str:
1490
  return get_model_and_label_id(self.model_name, self.label)
@@ -1528,52 +1537,76 @@ class OpenAiInferenceEngine(
1528
  if v is not None
1529
  }
1530
 
1531
- def _infer(
1532
  self,
1533
  dataset: Union[List[Dict[str, Any]], Dataset],
 
1534
  return_meta_data: bool = False,
1535
  ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
 
1536
  outputs = []
1537
- for instance in tqdm(dataset, desc="Inferring with openAI API"):
1538
- messages = self.to_messages(instance)
1539
- response = self.client.chat.completions.create(
1540
- messages=messages,
1541
- model=self.model_name,
1542
- **self._get_completion_kwargs(),
1543
- )
1544
- prediction = response.choices[0].message.content
1545
- output = self.get_return_object(prediction, response, return_meta_data)
1546
-
1547
- outputs.append(output)
1548
 
1549
  return outputs
1550
 
 
 
 
 
 
 
 
 
 
 
 
1551
  def _infer_log_probs(
1552
  self,
1553
  dataset: Union[List[Dict[str, Any]], Dataset],
1554
  return_meta_data: bool = False,
1555
  ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
1556
- outputs = []
1557
- for instance in tqdm(dataset, desc="Inferring with openAI API"):
1558
- messages = self.to_messages(instance)
1559
- response = self.client.chat.completions.create(
1560
- messages=messages,
1561
- model=self.model_name,
1562
- **self._get_completion_kwargs(),
1563
- )
1564
- top_logprobs_response = response.choices[0].logprobs.content
1565
- pred_output = [
1566
- {
1567
- "top_tokens": [
1568
- {"text": obj.token, "logprob": obj.logprob}
1569
- for obj in generated_token.top_logprobs
1570
- ]
1571
- }
1572
- for generated_token in top_logprobs_response
1573
- ]
1574
- output = self.get_return_object(pred_output, response, return_meta_data)
1575
- outputs.append(output)
1576
- return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1577
 
1578
  def get_return_object(self, predict_result, response, return_meta_data):
1579
  if return_meta_data:
@@ -1807,16 +1840,19 @@ class WMLInferenceEngineBase(
1807
  ):
1808
  """Base for classes running inference using ibm-watsonx-ai.
1809
 
1810
- Attributes:
1811
- credentials (Dict[str, str], optional): By default, it is created by a class
 
1812
  instance which tries to retrieve proper environment variables
1813
  ("WML_URL", "WML_PROJECT_ID", "WML_SPACE_ID", "WML_APIKEY", "WML_USERNAME", "WML_PASSWORD").
1814
  However, a dictionary with the following keys: "url", "apikey", "project_id", "space_id",
1815
  "username", "password".
1816
  can be directly provided instead.
1817
- model_name (str, optional): ID of a model to be used for inference. Mutually
 
1818
  exclusive with 'deployment_id'.
1819
- deployment_id (str, optional): Deployment ID of a tuned model to be used for
 
1820
  inference. Mutually exclusive with 'model_name'.
1821
  parameters (Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin], optional):
1822
  Defines inference parameters and their values. Deprecated attribute, please pass respective
@@ -2077,9 +2113,10 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
2077
 
2078
  If you want to include images in your input, please use 'WMLInferenceEngineChat' instead.
2079
 
2080
- Attributes:
2081
- concurrency_limit (int): Number of concurrent requests sent to a model. Default is 10,
2082
- which is also the maximum value.
 
2083
 
2084
  Examples:
2085
  .. code-block:: python
@@ -2207,10 +2244,11 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
2207
  concatenate images within an instance into a single image and adjust your query
2208
  accordingly (if necessary).
2209
 
2210
- Attributes:
2211
- image_encoder (EncodeImageToString, optional): operator which encodes images in
2212
- given format to base64 strings required by service. You should specify it when
2213
- you are using images in your inputs.
 
2214
 
2215
  Example:
2216
  .. code-block:: python
 
9
  import time
10
  import uuid
11
  from collections import Counter
12
+ from multiprocessing.pool import ThreadPool
13
  from typing import (
14
  Any,
15
  Dict,
 
64
  n: Optional[int] = None
65
  parallel_tool_calls: Optional[bool] = None
66
  service_tier: Optional[Literal["auto", "default"]] = None
67
+ credentials: Optional[Dict[str, str]] = {}
68
+ extra_headers: Optional[Dict[str, str]] = None
69
 
70
 
71
  def get_model_and_label_id(model_name, label):
 
1172
  for option in instance["task_data"]["options"]
1173
  ]
1174
 
1175
+ dataset_with_options_logprobs: List[
1176
+ List[Dict[str, Union[float, str]]]
1177
  ] = self.get_options_log_probs(dataset_with_options)
1178
 
1179
  dataset_iterator = iter(dataset_with_options_logprobs)
 
1470
  service_tier: Optional[Literal["auto", "default"]] = None
1471
 
1472
 
1473
+ def run_with_imap(func):
1474
+ def inner(self, args):
1475
+ return func(self, *args)
1476
+
1477
+ return inner
1478
+
1479
+
1480
  class OpenAiInferenceEngine(
1481
  InferenceEngine,
1482
  LogProbInferenceEngine,
 
1493
  base_url: Optional[str] = None
1494
  default_headers: Dict[str, str] = {}
1495
  credentials: CredentialsOpenAi = {}
1496
+ num_parallel_requests: int = 20
1497
 
1498
  def get_engine_id(self) -> str:
1499
  return get_model_and_label_id(self.model_name, self.label)
 
1537
  if v is not None
1538
  }
1539
 
1540
+ def _parallel_infer(
1541
  self,
1542
  dataset: Union[List[Dict[str, Any]], Dataset],
1543
+ infer_func,
1544
  return_meta_data: bool = False,
1545
  ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
1546
+ inputs = [(instance, return_meta_data) for instance in dataset]
1547
  outputs = []
1548
+ with ThreadPool(processes=self.num_parallel_requests) as pool:
1549
+ for output in tqdm(
1550
+ pool.imap(infer_func, inputs),
1551
+ total=len(inputs),
1552
+ desc=f"Inferring with {self.__class__.__name__}",
1553
+ ):
1554
+ outputs.append(output)
 
 
 
 
1555
 
1556
  return outputs
1557
 
1558
+ def _infer(
1559
+ self,
1560
+ dataset: Union[List[Dict[str, Any]], Dataset],
1561
+ return_meta_data: bool = False,
1562
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
1563
+ return self._parallel_infer(
1564
+ dataset=dataset,
1565
+ return_meta_data=return_meta_data,
1566
+ infer_func=self._get_chat_completion,
1567
+ )
1568
+
1569
  def _infer_log_probs(
1570
  self,
1571
  dataset: Union[List[Dict[str, Any]], Dataset],
1572
  return_meta_data: bool = False,
1573
  ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
1574
+ return self._parallel_infer(
1575
+ dataset=dataset,
1576
+ return_meta_data=return_meta_data,
1577
+ infer_func=self._get_logprobs,
1578
+ )
1579
+
1580
+ @run_with_imap
1581
+ def _get_chat_completion(self, instance, return_meta_data):
1582
+ messages = self.to_messages(instance)
1583
+ response = self.client.chat.completions.create(
1584
+ messages=messages,
1585
+ model=self.model_name,
1586
+ **self._get_completion_kwargs(),
1587
+ )
1588
+ prediction = response.choices[0].message.content
1589
+ return self.get_return_object(prediction, response, return_meta_data)
1590
+
1591
+ @run_with_imap
1592
+ def _get_logprobs(self, instance, return_meta_data):
1593
+ messages = self.to_messages(instance)
1594
+ response = self.client.chat.completions.create(
1595
+ messages=messages,
1596
+ model=self.model_name,
1597
+ **self._get_completion_kwargs(),
1598
+ )
1599
+ top_logprobs_response = response.choices[0].logprobs.content
1600
+ pred_output = [
1601
+ {
1602
+ "top_tokens": [
1603
+ {"text": obj.token, "logprob": obj.logprob}
1604
+ for obj in generated_token.top_logprobs
1605
+ ]
1606
+ }
1607
+ for generated_token in top_logprobs_response
1608
+ ]
1609
+ return self.get_return_object(pred_output, response, return_meta_data)
1610
 
1611
  def get_return_object(self, predict_result, response, return_meta_data):
1612
  if return_meta_data:
 
1840
  ):
1841
  """Base for classes running inference using ibm-watsonx-ai.
1842
 
1843
+ Args:
1844
+ credentials (Dict[str, str], optional):
1845
+ By default, it is created by a class
1846
  instance which tries to retrieve proper environment variables
1847
  ("WML_URL", "WML_PROJECT_ID", "WML_SPACE_ID", "WML_APIKEY", "WML_USERNAME", "WML_PASSWORD").
1848
  However, a dictionary with the following keys: "url", "apikey", "project_id", "space_id",
1849
  "username", "password".
1850
  can be directly provided instead.
1851
+ model_name (str, optional):
1852
+ ID of a model to be used for inference. Mutually
1853
  exclusive with 'deployment_id'.
1854
+ deployment_id (str, optional):
1855
+ Deployment ID of a tuned model to be used for
1856
  inference. Mutually exclusive with 'model_name'.
1857
  parameters (Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin], optional):
1858
  Defines inference parameters and their values. Deprecated attribute, please pass respective
 
2113
 
2114
  If you want to include images in your input, please use 'WMLInferenceEngineChat' instead.
2115
 
2116
+ Args:
2117
+ concurrency_limit (int):
2118
+ Number of concurrent requests sent to a model. Default is 10,
2119
+ which is also the maximum value.
2120
 
2121
  Examples:
2122
  .. code-block:: python
 
2244
  concatenate images within an instance into a single image and adjust your query
2245
  accordingly (if necessary).
2246
 
2247
+ Args:
2248
+ image_encoder (EncodeImageToString, optional):
2249
+ operator which encodes images in
2250
+ given format to base64 strings required by service. You should specify it when
2251
+ you are using images in your inputs.
2252
 
2253
  Example:
2254
  .. code-block:: python
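The main functional change in this file replaces the sequential OpenAI loop with `ThreadPool.imap`, using the small `run_with_imap` decorator to unpack the `(instance, return_meta_data)` tuples that `imap` passes as a single argument. A self-contained sketch of the same pattern, with a stubbed request standing in for the real `chat.completions.create` call:

```python
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List, Tuple


def run_with_imap(func):
    """imap passes one tuple per task; unpack it into positional arguments."""
    def inner(self, args):
        return func(self, *args)
    return inner


class ParallelClient:
    num_parallel_requests = 4

    @run_with_imap
    def _get_completion(self, instance: Dict[str, Any], return_meta_data: bool):
        # Stubbed "request": the real engine calls client.chat.completions.create(...)
        prediction = instance["source"].upper()
        return {"prediction": prediction, "meta": instance} if return_meta_data else prediction

    def infer(self, dataset: List[Dict[str, Any]], return_meta_data: bool = False):
        inputs: List[Tuple[Dict[str, Any], bool]] = [(instance, return_meta_data) for instance in dataset]
        with ThreadPool(processes=self.num_parallel_requests) as pool:
            # imap preserves input order, so outputs stay aligned with the dataset.
            return list(pool.imap(self._get_completion, inputs))


print(ParallelClient().infer([{"source": "a"}, {"source": "b"}]))  # ['A', 'B']
```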
llm_as_judge.py CHANGED
@@ -1,6 +1,6 @@
1
  import itertools
2
  from difflib import get_close_matches
3
- from typing import List, Optional, Union
4
 
5
  from .api import infer
6
  from .artifact import fetch_artifact
@@ -145,7 +145,7 @@ class LLMJudge(BulkInstanceMetric):
145
  )
146
  return
147
 
148
- def get_contexts(self, task_data: list[dict[str, any]]) -> list[dict[str, str]]:
149
  return [
150
  get_parsed_context(
151
  {
@@ -161,7 +161,7 @@ class LLMJudge(BulkInstanceMetric):
161
  instances: list,
162
  task: Task,
163
  template: Template,
164
- previous_messages: Optional[list[dict[str, str]]] = None,
165
  ):
166
  outputs_dataset = infer(
167
  instances,
@@ -172,11 +172,11 @@ class LLMJudge(BulkInstanceMetric):
172
  return_data=True,
173
  previous_messages=previous_messages,
174
  )
175
- prompts: list[str] = [instance["source"] for instance in outputs_dataset]
176
- raw_predictions: list[str] = [
177
  instance["raw_prediction"] for instance in outputs_dataset
178
  ]
179
- predictions: list[str] = [
180
  instance["prediction"] for instance in outputs_dataset
181
  ]
182
  return (prompts, raw_predictions, predictions)
@@ -274,7 +274,7 @@ class LLMJudgeDirect(LLMJudge):
274
  raise Exception(
275
  f"The type of the criteria must be 'CriteriaWithOptions', instead it is of type '{type(self.criteria)}'"
276
  )
277
- criterias: list[CriteriaWithOptions] = [self.criteria] * eval_count
278
  unique_criterias = list({criteria.name for criteria in criterias})
279
  self.logger.info(f"Criteria names are '{', '.join(unique_criterias)}'")
280
  return criterias
@@ -289,8 +289,8 @@ class LLMJudgeDirect(LLMJudge):
289
  option_selection_outputs,
290
  selections,
291
  evaluations_count,
292
- criterias: list[CriteriaWithOptions],
293
- ) -> list[dict[str, any]]:
294
  positional_bias = None
295
  if self.check_positional_bias:
296
  positional_bias = [
@@ -353,9 +353,9 @@ class LLMJudgeDirect(LLMJudge):
353
 
354
  def compute(
355
  self,
356
- references: list[list[str]],
357
- predictions: list[str],
358
- task_data: list[dict[str, any]],
359
  ) -> dict:
360
  self.logger.info(
361
  f'Starting evaluation with evaluator "{self.evaluator_name}" and provider "{self.inference_engine.get_pretty_print_name()}'
@@ -545,7 +545,7 @@ class LLMJudgePairwise(LLMJudge):
545
  f"The type of the criteria must be 'Criteria', instead it is of type '{type(self.criteria)}'"
546
  )
547
 
548
- criterias: list[Criteria] = [self.criteria] * eval_count
549
 
550
  unique_criterias = list({criteria.name for criteria in criterias})
551
  self.logger.info(f"Criteria names are '{', '.join(unique_criterias)}'")
@@ -553,7 +553,7 @@ class LLMJudgePairwise(LLMJudge):
553
 
554
  def get_instance_results(
555
  self,
556
- instance_predictions: dict[str, str],
557
  assessment_prompts,
558
  assessment_outputs,
559
  summarization_prompts,
@@ -728,7 +728,7 @@ class LLMJudgePairwise(LLMJudge):
728
  all_results["criteria"] = criteria.to_json()
729
  return self.clean_results(all_results)
730
 
731
- def parse_prediction_to_dict(self, prediction: Union[dict[str, str], list[str]]):
732
  if isinstance(prediction, list):
733
  return {f"{key + 1}": value for key, value in enumerate(prediction)}
734
 
@@ -740,15 +740,15 @@ class LLMJudgePairwise(LLMJudge):
740
  )
741
 
742
  def convert_predictions_to_dicts(
743
- self, predictions: Union[list[dict[str, str], list[str]]]
744
  ):
745
  return [self.parse_prediction_to_dict(prediction) for prediction in predictions]
746
 
747
  def compute(
748
  self,
749
- references: list[list[str]],
750
- predictions: Union[list[dict[str, str], list[str]]],
751
- task_data: list[dict[str, str]],
752
  ) -> dict:
753
  self.logger.info(
754
  f'Starting evaluation with evaluator "{self.evaluator_name}" and provider {self.inference_engine.get_pretty_print_name()}'
@@ -775,8 +775,8 @@ class LLMJudgePairwise(LLMJudge):
775
  f"The evaluation will perform {sum(contests_count_list) * [1,2][self.check_positional_bias]} ({' + '.join([f'{c * [1,2][self.check_positional_bias]}' for c in contests_count_list])}) pairwise comparisons"
776
  )
777
 
778
- response_pairs_list: list[list[list[str]]] = []
779
- option_pairs_list: list[list[list[str]]] = []
780
  predictions_names = set(predictions[0].keys())
781
  for i, combination_indexes in enumerate(combination_indexes_list):
782
  instance_predictions = predictions[i]
@@ -786,8 +786,8 @@ class LLMJudgePairwise(LLMJudge):
786
  f"The set of prediction names is different between instance 0 and instance {i}. In prediction 0, it is {sorted(predictions_names)}. In prediction {i}, it is {sorted(instance_predictions_names)}. Make sure the same number of predictions is passed for all instances."
787
  )
788
 
789
- response_pairs: list[list[str]] = []
790
- option_pairs: list[list[str]] = []
791
  for combination in combination_indexes:
792
  (idx_1, idx_2) = combination
793
  response_name_1 = instance_predictions_names[idx_1]
 
1
  import itertools
2
  from difflib import get_close_matches
3
+ from typing import Any, Dict, List, Optional, Union
4
 
5
  from .api import infer
6
  from .artifact import fetch_artifact
 
145
  )
146
  return
147
 
148
+ def get_contexts(self, task_data: List[Dict[str, Any]]) -> List[Dict[str, str]]:
149
  return [
150
  get_parsed_context(
151
  {
 
161
  instances: list,
162
  task: Task,
163
  template: Template,
164
+ previous_messages: Optional[List[Dict[str, str]]] = None,
165
  ):
166
  outputs_dataset = infer(
167
  instances,
 
172
  return_data=True,
173
  previous_messages=previous_messages,
174
  )
175
+ prompts: List[str] = [instance["source"] for instance in outputs_dataset]
176
+ raw_predictions: List[str] = [
177
  instance["raw_prediction"] for instance in outputs_dataset
178
  ]
179
+ predictions: List[str] = [
180
  instance["prediction"] for instance in outputs_dataset
181
  ]
182
  return (prompts, raw_predictions, predictions)
 
274
  raise Exception(
275
  f"The type of the criteria must be 'CriteriaWithOptions', instead it is of type '{type(self.criteria)}'"
276
  )
277
+ criterias: List[CriteriaWithOptions] = [self.criteria] * eval_count
278
  unique_criterias = list({criteria.name for criteria in criterias})
279
  self.logger.info(f"Criteria names are '{', '.join(unique_criterias)}'")
280
  return criterias
 
289
  option_selection_outputs,
290
  selections,
291
  evaluations_count,
292
+ criterias: List[CriteriaWithOptions],
293
+ ) -> List[Dict[str, Any]]:
294
  positional_bias = None
295
  if self.check_positional_bias:
296
  positional_bias = [
 
353
 
354
  def compute(
355
  self,
356
+ references: List[List[str]],
357
+ predictions: List[str],
358
+ task_data: List[Dict[str, Any]],
359
  ) -> dict:
360
  self.logger.info(
361
  f'Starting evaluation with evaluator "{self.evaluator_name}" and provider "{self.inference_engine.get_pretty_print_name()}'
 
545
  f"The type of the criteria must be 'Criteria', instead it is of type '{type(self.criteria)}'"
546
  )
547
 
548
+ criterias: List[Criteria] = [self.criteria] * eval_count
549
 
550
  unique_criterias = list({criteria.name for criteria in criterias})
551
  self.logger.info(f"Criteria names are '{', '.join(unique_criterias)}'")
 
553
 
554
  def get_instance_results(
555
  self,
556
+ instance_predictions: Dict[str, str],
557
  assessment_prompts,
558
  assessment_outputs,
559
  summarization_prompts,
 
728
  all_results["criteria"] = criteria.to_json()
729
  return self.clean_results(all_results)
730
 
731
+ def parse_prediction_to_dict(self, prediction: Union[Dict[str, str], List[str]]):
732
  if isinstance(prediction, list):
733
  return {f"{key + 1}": value for key, value in enumerate(prediction)}
734
 
 
740
  )
741
 
742
  def convert_predictions_to_dicts(
743
+ self, predictions: Union[List[Dict[str, str]], List[str]]
744
  ):
745
  return [self.parse_prediction_to_dict(prediction) for prediction in predictions]
746
 
747
  def compute(
748
  self,
749
+ references: List[List[str]],
750
+ predictions: Union[List[Dict[str, str]], List[str]],
751
+ task_data: List[Dict[str, str]],
752
  ) -> dict:
753
  self.logger.info(
754
  f'Starting evaluation with evaluator "{self.evaluator_name}" and provider {self.inference_engine.get_pretty_print_name()}'
 
775
  f"The evaluation will perform {sum(contests_count_list) * [1,2][self.check_positional_bias]} ({' + '.join([f'{c * [1,2][self.check_positional_bias]}' for c in contests_count_list])}) pairwise comparisons"
776
  )
777
 
778
+ response_pairs_list: List[List[List[str]]] = []
779
+ option_pairs_list: List[List[List[str]]] = []
780
  predictions_names = set(predictions[0].keys())
781
  for i, combination_indexes in enumerate(combination_indexes_list):
782
  instance_predictions = predictions[i]
 
786
  f"The set of prediction names is different between instance 0 and instance {i}. In prediction 0, it is {sorted(predictions_names)}. In prediction {i}, it is {sorted(instance_predictions_names)}. Make sure the same number of predictions is passed for all instances."
787
  )
788
 
789
+ response_pairs: List[List[str]] = []
790
+ option_pairs: List[List[str]] = []
791
  for combination in combination_indexes:
792
  (idx_1, idx_2) = combination
793
  response_name_1 = instance_predictions_names[idx_1]
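For orientation, the pairwise judging code above builds one contest per unordered pair of responses. A standalone sketch of that pairing step (variable names follow the diff; the sample predictions are illustrative):

```python
from itertools import combinations
from typing import Dict, List, Tuple

predictions: Dict[str, str] = {"system_a": "Paris", "system_b": "Lyon", "system_c": "Paris"}

# Every unordered pair of responses becomes one judged contest.
combination_indexes: List[Tuple[int, int]] = list(combinations(range(len(predictions)), 2))

names = list(predictions.keys())
response_pairs: List[List[str]] = []
option_pairs: List[List[str]] = []
for idx_1, idx_2 in combination_indexes:
    response_pairs.append([predictions[names[idx_1]], predictions[names[idx_2]]])
    option_pairs.append([names[idx_1], names[idx_2]])

print(option_pairs)         # [['system_a', 'system_b'], ['system_a', 'system_c'], ['system_b', 'system_c']]
print(len(response_pairs))  # 3 pairwise comparisons for 3 responses
```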
llm_as_judge_constants.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  from enum import Enum
3
- from typing import Optional
4
 
5
  from .artifact import Artifact
6
  from .inference import (
@@ -36,15 +36,15 @@ class Criteria(Artifact):
36
 
37
 
38
  class CriteriaWithOptions(Criteria):
39
- options: list[CriteriaOption]
40
- option_map: Optional[dict[str, float]] = None
41
 
42
  @staticmethod
43
  def from_jsons(s: str):
44
  return CriteriaWithOptions.from_obj(json.loads(s))
45
 
46
  @staticmethod
47
- def from_obj(criteria_dict: dict):
48
  return CriteriaWithOptions(
49
  name=criteria_dict["name"],
50
  description=criteria_dict["description"],
@@ -132,7 +132,7 @@ PROVIDER_TO_STRATEGY = {
132
 
133
  class EvaluatorMetadata:
134
  name: EvaluatorNameEnum
135
- providers: list[ModelProviderEnum]
136
 
137
  def __init__(self, name, providers):
138
  self.name = name
 
1
  import json
2
  from enum import Enum
3
+ from typing import Dict, List, Optional
4
 
5
  from .artifact import Artifact
6
  from .inference import (
 
36
 
37
 
38
  class CriteriaWithOptions(Criteria):
39
+ options: List[CriteriaOption]
40
+ option_map: Optional[Dict[str, float]] = None
41
 
42
  @staticmethod
43
  def from_jsons(s: str):
44
  return CriteriaWithOptions.from_obj(json.loads(s))
45
 
46
  @staticmethod
47
+ def from_obj(criteria_dict: Dict):
48
  return CriteriaWithOptions(
49
  name=criteria_dict["name"],
50
  description=criteria_dict["description"],
 
132
 
133
  class EvaluatorMetadata:
134
  name: EvaluatorNameEnum
135
+ providers: List[ModelProviderEnum]
136
 
137
  def __init__(self, name, providers):
138
  self.name = name
llm_as_judge_utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from .llm_as_judge_constants import (
2
  EVALUATORS_METADATA,
3
  MODEL_RENAMINGS,
@@ -7,7 +9,7 @@ from .llm_as_judge_constants import (
7
  )
8
 
9
 
10
- def get_parsed_context(context: dict[str, str]):
11
  return (
12
  "\n".join([f"{key}: {value}" for key, value in context.items()])
13
  if len(context) > 1
 
1
+ from typing import Dict
2
+
3
  from .llm_as_judge_constants import (
4
  EVALUATORS_METADATA,
5
  MODEL_RENAMINGS,
 
9
  )
10
 
11
 
12
+ def get_parsed_context(context: Dict[str, str]):
13
  return (
14
  "\n".join([f"{key}: {value}" for key, value in context.items()])
15
  if len(context) > 1
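A hedged sketch of `get_parsed_context` based on the visible branch (joining `key: value` lines when there is more than one entry); the single-entry branch lies outside the hunk, so returning the lone value is an assumption:

```python
from typing import Dict


def get_parsed_context(context: Dict[str, str]) -> str:
    """Render a context dict as 'key: value' lines; single-entry handling is assumed."""
    if len(context) > 1:
        return "\n".join(f"{key}: {value}" for key, value in context.items())
    return next(iter(context.values()), "")


print(get_parsed_context({"document": "...", "question": "Who wrote it?"}))
```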
loaders.py CHANGED
@@ -41,6 +41,7 @@ from tempfile import TemporaryDirectory
41
  from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
42
 
43
  import pandas as pd
 
44
  from datasets import load_dataset as hf_load_dataset
45
  from huggingface_hub import HfApi
46
  from tqdm import tqdm
@@ -51,7 +52,7 @@ from .logging_utils import get_logger
51
  from .operator import SourceOperator
52
  from .operators import Set
53
  from .settings_utils import get_settings
54
- from .stream import DynamicStream, MultiStream
55
  from .type_utils import isoftype
56
  from .utils import LRUCache
57
 
@@ -122,7 +123,7 @@ class Loader(SourceOperator):
122
  )
123
  return operator(multi_stream)
124
 
125
- def sef_default_data_classification(
126
  self, default_data_classification_policy, additional_info
127
  ):
128
  if self.data_classification_policy is None:
@@ -162,23 +163,24 @@ class LoadHF(Loader):
162
  and it can filter datasets upon loading.
163
 
164
  Args:
165
- path: The path or identifier of the dataset on the HuggingFace Hub.
166
-
167
- name: An optional dataset name.
168
-
169
- data_dir: Optional directory to store downloaded data.
170
-
171
- split: Optional specification of which split to load.
172
-
173
- data_files: Optional specification of particular data files to load.
174
-
175
- revision: Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
176
-
177
- streaming (bool): indicating if streaming should be used.
178
-
179
- filtering_lambda: A lambda function for filtering the data after loading.
180
-
181
- num_proc (int): Optional integer to specify the number of processes to use for parallel dataset loading.
 
182
 
183
  Example:
184
  Loading glue's mrpc dataset
@@ -278,40 +280,22 @@ class LoadHF(Loader):
278
  for split in dataset.keys():
279
  dataset[split] = dataset[split].to_iterable_dataset()
280
  else:
281
- dataset = {self.split: dataset}
282
-
283
- if self.filtering_lambda is not None:
284
- dataset = self.filter_load(dataset)
285
 
286
  return dataset
287
 
288
- def split_limited_load(self, dataset, split_name):
289
- yield from itertools.islice(dataset[split_name], self.get_limit())
290
-
291
- def limited_load(self, dataset):
292
- self.log_limited_loading()
293
- return MultiStream(
294
- {
295
- name: DynamicStream(
296
- generator=self.split_limited_load,
297
- gen_kwargs={"dataset": dataset, "split_name": name},
298
- )
299
- for name in dataset.keys()
300
- }
301
- )
302
-
303
  def _maybe_set_classification_policy(self):
304
  if os.path.exists(self.path):
305
- self.sef_default_data_classification(
306
  ["proprietary"], "when loading from local files"
307
  )
308
  else:
309
- self.sef_default_data_classification(
310
  ["public"],
311
  None, # No warning when loading from public hub
312
  )
313
 
314
- def load_iterables(self):
315
  try:
316
  dataset = self.stream_dataset()
317
  except (
@@ -319,8 +303,15 @@ class LoadHF(Loader):
319
  ): # streaming is not supported for zipped files so we load without streaming
320
  dataset = self.load_dataset()
321
 
 
 
 
322
  if self.get_limit() is not None:
323
- return self.limited_load(dataset=dataset)
 
 
 
 
324
 
325
  return dataset
326
 
@@ -352,7 +343,7 @@ class LoadCSV(Loader):
352
  sep: str = ","
353
 
354
  def _maybe_set_classification_policy(self):
355
- self.sef_default_data_classification(
356
  ["proprietary"], "when loading from local files"
357
  )
358
 
@@ -365,9 +356,7 @@ class LoadCSV(Loader):
365
  file_path, nrows=self.get_limit(), sep=self.sep
366
  ).to_dict("records")
367
  else:
368
- iterables[split_name] = pd.read_csv(file_path, sep=self.sep).to_dict(
369
- "records"
370
- )
371
  return iterables
372
 
373
 
@@ -475,14 +464,22 @@ class LoadFromIBMCloud(Loader):
475
  3. Mapping: split -> file_names, e.g. {"test" : ["test1.json", "test2.json"], "train": ["train.json"]}
476
 
477
  Args:
478
- endpoint_url_env: Environment variable name for the IBM Cloud endpoint URL.
479
- aws_access_key_id_env: Environment variable name for the AWS access key ID.
480
- aws_secret_access_key_env: Environment variable name for the AWS secret access key.
481
- bucket_name: Name of the S3 bucket from which to load data.
482
- data_dir: Optional directory path within the bucket.
483
- data_files: Union type allowing either a list of file names or a mapping of splits to file names.
484
- data_field: The dataset key for nested JSON file, i.e. when multiple datasets are nested in the same file
485
- caching: Bool indicating if caching is enabled to avoid re-downloading data.
 
 
 
 
 
 
 
 
486
 
487
  Example:
488
  Loading from IBM Cloud
@@ -578,7 +575,7 @@ class LoadFromIBMCloud(Loader):
578
  raise NotImplementedError("LoadFromKaggle cannot load with streaming.")
579
 
580
  def _maybe_set_classification_policy(self):
581
- self.sef_default_data_classification(
582
  ["proprietary"], "when loading from IBM COS"
583
  )
584
 
@@ -729,7 +726,7 @@ class LoadFromDictionary(Loader):
729
  )
730
 
731
  def _maybe_set_classification_policy(self):
732
- self.sef_default_data_classification(
733
  ["proprietary"], "when loading from python dictionary"
734
  )
735
 
@@ -744,25 +741,24 @@ class LoadFromHFSpace(LoadHF):
744
  from the given space and then reads them as a HuggingFace Dataset.
745
 
746
  Args:
747
- space_name (str): Name of the HuggingFace Space to be accessed.
748
-
749
- data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative
750
- paths to files within a given repository. If given as a mapping, paths should
751
- be values, while keys should represent the type of respective files
752
- (training, testing etc.).
753
-
754
- path (str, optional): Absolute path to a directory where data should be downloaded.
755
-
756
- revision (str, optional): ID of a Git branch or commit to be used. By default, it is
757
- set to None, thus data is downloaded from the main branch of the accessed
758
- repository.
759
-
760
- use_token (bool, optional): Whether a token is used for authentication when accessing
761
- the HuggingFace Space. If necessary, the token is read from the HuggingFace
762
- config folder.
763
-
764
- token_env (str, optional): Key of an env variable which value will be used for
765
- authentication when accessing the HuggingFace Space - if necessary.
766
 
767
  Example:
768
  Loading from a HuggingFace Space
@@ -910,7 +906,7 @@ class LoadFromHFSpace(LoadHF):
910
  )
911
 
912
  def _maybe_set_classification_policy(self):
913
- self.sef_default_data_classification(
914
  ["public"], "when loading from Huggingface spaces"
915
  )
916
 
 
41
  from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
42
 
43
  import pandas as pd
44
+ from datasets import IterableDatasetDict
45
  from datasets import load_dataset as hf_load_dataset
46
  from huggingface_hub import HfApi
47
  from tqdm import tqdm
 
52
  from .operator import SourceOperator
53
  from .operators import Set
54
  from .settings_utils import get_settings
55
+ from .stream import MultiStream
56
  from .type_utils import isoftype
57
  from .utils import LRUCache
58
 
 
123
  )
124
  return operator(multi_stream)
125
 
126
+ def set_default_data_classification(
127
  self, default_data_classification_policy, additional_info
128
  ):
129
  if self.data_classification_policy is None:
 
163
  and it can filter datasets upon loading.
164
 
165
  Args:
166
+ path:
167
+ The path or identifier of the dataset on the HuggingFace Hub.
168
+ name:
169
+ An optional dataset name.
170
+ data_dir:
171
+ Optional directory to store downloaded data.
172
+ split:
173
+ Optional specification of which split to load.
174
+ data_files:
175
+ Optional specification of particular data files to load.
176
+ revision:
177
+ Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
178
+ streaming (bool):
179
+ indicating if streaming should be used.
180
+ filtering_lambda (str, optional):
181
+ A lambda function for filtering the data after loading.
182
+ num_proc (int, optional):
183
+ Specifies the number of processes to use for parallel dataset loading.
184
 
185
  Example:
186
  Loading glue's mrpc dataset
 
280
  for split in dataset.keys():
281
  dataset[split] = dataset[split].to_iterable_dataset()
282
  else:
283
+ dataset = {self.split: dataset.to_iterable_dataset()}
 
 
 
284
 
285
  return dataset
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  def _maybe_set_classification_policy(self):
288
  if os.path.exists(self.path):
289
+ self.set_default_data_classification(
290
  ["proprietary"], "when loading from local files"
291
  )
292
  else:
293
+ self.set_default_data_classification(
294
  ["public"],
295
  None, # No warning when loading from public hub
296
  )
297
 
298
+ def load_iterables(self) -> IterableDatasetDict:
299
  try:
300
  dataset = self.stream_dataset()
301
  except (
 
303
  ): # streaming is not supported for zipped files so we load without streaming
304
  dataset = self.load_dataset()
305
 
306
+ if self.filtering_lambda is not None:
307
+ dataset = self.filter_load(dataset)
308
+
309
  if self.get_limit() is not None:
310
+ self.log_limited_loading()
311
+ return {
312
+ split_name: dataset[split_name].take(self.get_limit())
313
+ for split_name in dataset
314
+ }
315
 
316
  return dataset
317
 
 
343
  sep: str = ","
344
 
345
  def _maybe_set_classification_policy(self):
346
+ self.set_default_data_classification(
347
  ["proprietary"], "when loading from local files"
348
  )
349
 
 
356
  file_path, nrows=self.get_limit(), sep=self.sep
357
  ).to_dict("records")
358
  else:
359
+ iterables[split_name] = pd.read_csv(file_path).to_dict("records")
 
 
360
  return iterables
361
 
362
 
 
464
  3. Mapping: split -> file_names, e.g. {"test" : ["test1.json", "test2.json"], "train": ["train.json"]}
465
 
466
  Args:
467
+ endpoint_url_env:
468
+ Environment variable name for the IBM Cloud endpoint URL.
469
+ aws_access_key_id_env:
470
+ Environment variable name for the AWS access key ID.
471
+ aws_secret_access_key_env:
472
+ Environment variable name for the AWS secret access key.
473
+ bucket_name:
474
+ Name of the S3 bucket from which to load data.
475
+ data_dir:
476
+ Optional directory path within the bucket.
477
+ data_files:
478
+ Union type allowing either a list of file names or a mapping of splits to file names.
479
+ data_field:
480
+ The dataset key for nested JSON file, i.e. when multiple datasets are nested in the same file
481
+ caching (bool):
482
+ indicating if caching is enabled to avoid re-downloading data.
483
 
484
  Example:
485
  Loading from IBM Cloud
 
575
  raise NotImplementedError("LoadFromKaggle cannot load with streaming.")
576
 
577
  def _maybe_set_classification_policy(self):
578
+ self.set_default_data_classification(
579
  ["proprietary"], "when loading from IBM COS"
580
  )
581
 
 
726
  )
727
 
728
  def _maybe_set_classification_policy(self):
729
+ self.set_default_data_classification(
730
  ["proprietary"], "when loading from python dictionary"
731
  )
732
 
 
741
  from the given space and then reads them as a HuggingFace Dataset.
742
 
743
  Args:
744
+ space_name (str):
745
+ Name of the HuggingFace Space to be accessed.
746
+ data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]):
747
+ Relative paths to files within a given repository. If given as a mapping,
748
+ paths should be values, while keys should represent the type of respective files
749
+ (training, testing etc.).
750
+ path (str, optional):
751
+ Absolute path to a directory where data should be downloaded.
752
+ revision (str, optional):
753
+ ID of a Git branch or commit to be used. By default, it is set to None,
754
+ thus data is downloaded from the main branch of the accessed repository.
755
+ use_token (bool, optional):
756
+ Whether a token is used for authentication when accessing
757
+ the HuggingFace Space. If necessary, the token is read from the HuggingFace
758
+ config folder.
759
+ token_env (str, optional):
760
+ Key of an env variable which value will be used for
761
+ authentication when accessing the HuggingFace Space - if necessary.
 
762
 
763
  Example:
764
  Loading from a HuggingFace Space
 
906
  )
907
 
908
  def _maybe_set_classification_policy(self):
909
+ self.set_default_data_classification(
910
  ["public"], "when loading from Huggingface spaces"
911
  )
912
 
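The limited-loading rewrite above drops the `DynamicStream`-based helper in favour of `IterableDataset.take()`. The same idea expressed directly against the `datasets` library (the dataset and limit below are illustrative, and the snippet streams data over the network):

```python
from datasets import load_dataset

# streaming=True returns an IterableDatasetDict of IterableDataset objects.
streams = load_dataset("glue", "mrpc", streaming=True)

limit = 5  # stands in for self.get_limit()
limited = {split_name: streams[split_name].take(limit) for split_name in streams}

for row in limited["train"]:
    print(row["sentence1"][:40])
```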
metric_utils.py CHANGED
@@ -353,13 +353,11 @@ UNITXT_METRIC_SCHEMA = Features(
353
  class GlobalScores(dict):
354
  """GlobalScores is a dictionary-based class designed to handle and transform metric results into a structured format.
355
 
356
- Attributes:
357
- score (float): The main score value.
358
- score_name (str): The name of the main score.
359
-
360
- Methods:
361
- to_df():
362
- Transforms the dictionary of results into a pandas DataFrame with score_name as the index,
363
  """
364
 
365
  @property
@@ -550,12 +548,11 @@ class GroupsScores(dict):
550
  This class provides a property to summarize the scores and a custom
551
  string representation for pretty-printing.
552
 
553
- Attributes:
554
- summary (property): A property to get a summary of the group scores.
555
  """
556
 
557
  @property
558
  def summary(self):
 
559
  data = self
560
  # Desired metric columns
561
  metric_cols = [
 
353
  class GlobalScores(dict):
354
  """GlobalScores is a dictionary-based class designed to handle and transform metric results into a structured format.
355
 
356
+ Args:
357
+ score (float):
358
+ The main score value.
359
+ score_name (str):
360
+ The name of the main score.
 
 
361
  """
362
 
363
  @property
 
548
  This class provides a property to summarize the scores and a custom
549
  string representation for pretty-printing.
550
 
 
 
551
  """
552
 
553
  @property
554
  def summary(self):
555
+ """A property to get a summary of the group scores."""
556
  data = self
557
  # Desired metric columns
558
  metric_cols = [
metrics.py CHANGED
@@ -48,7 +48,7 @@ from .random_utils import get_seed
48
  from .settings_utils import get_settings
49
  from .stream import MultiStream, Stream
50
  from .type_utils import Type, isoftype, parse_type_string, to_type_string
51
- from .utils import deep_copy
52
 
53
  logger = get_logger()
54
  settings = get_settings()
@@ -992,7 +992,17 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
992
  reference_field: str = NonPositionalField(default="references")
993
  prediction_field: str = NonPositionalField(default="prediction")
994
 
995
- def _validate_group_mean_reduction(self, instances: List[dict]):
 
 
 
 
 
 
 
 
 
 
996
  """Ensure that group_mean reduction_map is properly formatted.
997
 
998
  Example: Apply the variance (np.var) to group Accuracy instance scores. This class would be specified as follows:
@@ -1042,17 +1052,6 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1042
  1 'How do I repair my engine?' 'paraphrase'
1043
  2 'Why are ants eating my food?' 'original'
1044
  """
1045
- # instances need to all have task_data field with field group_id
1046
- assert all(
1047
- "task_data" in instance for instance in instances
1048
- ), "each instance must have an task_data field"
1049
- assert all(
1050
- isinstance(instance["task_data"], dict) for instance in instances
1051
- ), "each instance must have an task_data field that is a dict"
1052
- assert all(
1053
- "group_id" in instance["task_data"] for instance in instances
1054
- ), "each instance task_data dict must have a key group_id"
1055
-
1056
  # validate the reduction_map
1057
  assert (
1058
  "group_mean" in self.reduction_map
@@ -1081,16 +1080,9 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1081
  if "score_fields" in fields:
1082
  assert isinstance(fields["score_fields"], list)
1083
 
1084
- # for aggregation functions that use the subgroup_column (expect a dict of lists), check that
1085
- # this field exists
1086
- if self.subgroup_column is not None:
1087
- assert all(
1088
- self.subgroup_column in instance["task_data"] for instance in instances
1089
- ), f"each instance task_data dict must have a key {self.subgroup_column}"
1090
-
1091
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1092
- instances = self.compute_instance_scores(stream)
1093
- global_score = {"num_of_instances": len(instances)}
1094
  for reduction_type, reduction_params in self.reduction_map.items():
1095
  assert (
1096
  reduction_type in self.implemented_reductions
@@ -1103,15 +1095,15 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1103
  aggregation_function = self.average_item_scores
1104
  reduction_fields = list(set(reduction_params))
1105
  # no group reduction, so resample instances individually
1106
- scores_to_resample = instances
1107
  elif reduction_type == "max":
1108
  aggregation_function = self.max_item_scores
1109
  reduction_fields = list(set(reduction_params))
1110
  # no group reduction, so resample instances individually
1111
- scores_to_resample = instances
1112
  elif reduction_type == "group_mean":
1113
  aggregation_function = self.average_item_scores
1114
- self._validate_group_mean_reduction(instances=instances)
1115
  reduction_fields = (
1116
  [self.main_score]
1117
  if "score_fields" not in reduction_params
@@ -1127,7 +1119,7 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1127
  scores_to_resample,
1128
  aggregation_function,
1129
  ) = self._set_up_group_mean_aggregation(
1130
- instances,
1131
  reduction_params,
1132
  reduction_fields,
1133
  )
@@ -1168,18 +1160,32 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1168
  )
1169
  global_score.update(confidence_interval)
1170
 
1171
- for instance in instances:
1172
  self.update_and_adjust_global_score(instance, global_score)
1173
- yield from instances
 
 
 
1174
 
1175
  def compute_instance_scores(
1176
  self, stream: Stream, stream_name: Optional[str] = None
1177
  ):
1178
- instances = []
1179
 
1180
  for instance in stream:
1181
  instance = self.verify_instance(instance)
1182
 
 
 
 
 
 
 
 
 
 
 
 
1183
  task_data = instance["task_data"] if "task_data" in instance else {}
1184
 
1185
  if self.reference_field == "references":
@@ -1214,9 +1220,18 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1214
  instance_score, instance["score"]["instance"]
1215
  )
1216
  )
1217
- instances.append(instance)
 
 
 
 
 
 
 
1218
 
1219
- return instances
 
 
1220
 
1221
  def get_group_scores(
1222
  self,
@@ -1228,12 +1243,16 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1228
  """Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
1229
 
1230
  Args:
1231
- instances: List of observation instances with instance-level scores (fields) computed.
1232
- score_names: List of instance score names in each instance to apply the aggregation function.
1233
- group_aggregation_func: Callable aggregation function accepting a list of numeric scores;
 
 
 
1234
  or, if self.subgroup_column is not None, a dict of subgroup types scores by subgroup_column value.
1235
  callable function returns a single score for the group
1236
- prepend_score_prefix: if True - prepend the score_prefix to the score names in the returned dicts. Set to False
 
1237
  if down the stream such a prepending is expected.
1238
 
1239
  Returns:
@@ -4910,14 +4929,18 @@ class IsCodeMixed(BulkInstanceMetric):
4910
  class MetricsEnsemble(InstanceMetric, ArtifactFetcherMixin):
4911
  """Metrics Ensemble class for creating ensemble of given metrics.
4912
 
4913
- Attributes:
4914
- main_score (str): The main score label used for evaluation.
4915
- metrics (List[Union[Metric, str]]): List of metrics that will be ensemble.
4916
- weights (List[float]): Weight of each the metrics
4917
- InstanceMetric currently allows two reductions:
4918
- reduction_map (Dict[str, List[str]]. Parameter for specifying the redaction method of the global score.
4919
- (see it definition at InstanceMetric class). This class define its default
4920
- value to reduce by the mean of the main score.
 
 
 
 
4921
 
4922
  """
4923
 
 
48
  from .settings_utils import get_settings
49
  from .stream import MultiStream, Stream
50
  from .type_utils import Type, isoftype, parse_type_string, to_type_string
51
+ from .utils import deep_copy, recursive_copy
52
 
53
  logger = get_logger()
54
  settings = get_settings()
 
992
  reference_field: str = NonPositionalField(default="references")
993
  prediction_field: str = NonPositionalField(default="prediction")
994
 
995
+ def _validate_group_mean_task_data(self, instance):
996
+ # instances need to all have task_data field with field group_id
997
+ assert "task_data" in instance, "each instance must have an task_data field"
998
+ assert isinstance(
999
+ instance["task_data"], dict
1000
+ ), "each instance must have an task_data field that is a dict"
1001
+ assert (
1002
+ "group_id" in instance["task_data"]
1003
+ ), "each instance task_data dict must have a key group_id"
1004
+
1005
+ def _validate_group_mean_reduction(self):
1006
  """Ensure that group_mean reduction_map is properly formatted.
1007
 
1008
  Example: Apply the variance (np.var) to group Accuracy instance scores. This class would be specified as follows:
 
1052
  1 'How do I repair my engine?' 'paraphrase'
1053
  2 'Why are ants eating my food?' 'original'
1054
  """
 
 
 
 
 
 
 
 
 
 
 
1055
  # validate the reduction_map
1056
  assert (
1057
  "group_mean" in self.reduction_map
 
1080
  if "score_fields" in fields:
1081
  assert isinstance(fields["score_fields"], list)
1082
 
 
 
 
 
 
 
 
1083
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1084
+ instance_scores = self.compute_instance_scores(stream)
1085
+ global_score = {"num_of_instances": len(instance_scores)}
1086
  for reduction_type, reduction_params in self.reduction_map.items():
1087
  assert (
1088
  reduction_type in self.implemented_reductions
 
1095
  aggregation_function = self.average_item_scores
1096
  reduction_fields = list(set(reduction_params))
1097
  # no group reduction, so resample instances individually
1098
+ scores_to_resample = instance_scores
1099
  elif reduction_type == "max":
1100
  aggregation_function = self.max_item_scores
1101
  reduction_fields = list(set(reduction_params))
1102
  # no group reduction, so resample instances individually
1103
+ scores_to_resample = instance_scores
1104
  elif reduction_type == "group_mean":
1105
  aggregation_function = self.average_item_scores
1106
+ self._validate_group_mean_reduction()
1107
  reduction_fields = (
1108
  [self.main_score]
1109
  if "score_fields" not in reduction_params
 
1119
  scores_to_resample,
1120
  aggregation_function,
1121
  ) = self._set_up_group_mean_aggregation(
1122
+ instance_scores,
1123
  reduction_params,
1124
  reduction_fields,
1125
  )
 
1160
  )
1161
  global_score.update(confidence_interval)
1162
 
1163
+ for instance in instance_scores:
1164
  self.update_and_adjust_global_score(instance, global_score)
1165
+
1166
+ for i, instance in enumerate(stream):
1167
+ instance["score"] = recursive_copy(instance_scores[i]["score"])
1168
+ yield instance
1169
 
1170
  def compute_instance_scores(
1171
  self, stream: Stream, stream_name: Optional[str] = None
1172
  ):
1173
+ instance_scores = []
1174
 
1175
  for instance in stream:
1176
  instance = self.verify_instance(instance)
1177
 
1178
+ if "group_mean" in self.reduction_map:
1179
+ self._validate_group_mean_task_data(instance)
1180
+
1181
+ # for aggregation functions that use the subgroup_column (expect a dict of lists), check that
1182
+ # this field exists
1183
+ if self.subgroup_column is not None:
1184
+ assert (
1185
+ "task_data" in instance
1186
+ and self.subgroup_column in instance["task_data"]
1187
+ ), f"each instance task_data dict must have a key {self.subgroup_column}"
1188
+
1189
  task_data = instance["task_data"] if "task_data" in instance else {}
1190
 
1191
  if self.reference_field == "references":
 
1220
  instance_score, instance["score"]["instance"]
1221
  )
1222
  )
1223
+ task_data = {}
1224
+ if "task_data" in instance:
1225
+ if "group_id" in instance["task_data"]:
1226
+ task_data["group_id"] = instance["task_data"]["group_id"]
1227
+ if self.subgroup_column in instance["task_data"]:
1228
+ task_data[self.subgroup_column] = instance["task_data"][
1229
+ self.subgroup_column
1230
+ ]
1231
 
1232
+ instance_scores.append({"score": instance["score"], "task_data": task_data})
1233
+
1234
+ return instance_scores
1235
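The refactored ``compute_instance_scores`` no longer accumulates whole instances; it keeps only each instance's ``score`` dict plus the few ``task_data`` keys needed later for grouping. A minimal sketch of what one returned entry might look like (the keys under ``instance`` depend on the metric, and ``variant_type`` stands in for a hypothetical ``subgroup_column``):

.. code-block:: python

    # Illustrative only -- not taken from the library's tests.
    example_entry = {
        "score": {
            "instance": {"accuracy": 1.0, "score": 1.0, "score_name": "accuracy"},
        },
        "task_data": {"group_id": "q1", "variant_type": "original"},
    }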
 
1236
  def get_group_scores(
1237
  self,
 
1243
  """Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
1244
 
1245
  Args:
1246
+ instances (list):
1247
+ List of observation instances with instance-level scores (fields) computed.
1248
+ score_names (list):
1249
+ List of instance score names in each instance to apply the aggregation function.
1250
+ group_aggregation_func (Callable):
1251
+ aggregation function accepting a list of numeric scores;
1252
  or, if self.subgroup_column is not None, a dict of subgroup-type scores keyed by subgroup_column value.
1253
  The callable returns a single score for the group.
1254
+ prepend_score_prefix (bool):
1255
+ if True, prepend the score_prefix to the score names in the returned dicts. Set to False
1256
  if such prepending is expected downstream.
1257
 
1258
  Returns:
 
4929
  class MetricsEnsemble(InstanceMetric, ArtifactFetcherMixin):
4930
  """Metrics Ensemble class for creating ensemble of given metrics.
4931
 
4932
+ Args:
4933
+ main_score (str):
4934
+ The main score label used for evaluation.
4935
+ metrics (List[Union[Metric, str]]):
4936
+ List of metrics to be ensembled.
4937
+ weights (List[float]):
4938
+ Weight of each of the metrics.
4939
+ reduction_map (Dict[str, List[str]]):
4940
+ Specifies the reduction method of the global score.
4941
+ InstanceMetric currently allows two reductions
4942
+ (see their definition in the InstanceMetric class).
4943
+ This class defines its default value to reduce by the mean of the main score.
4944
 
4945
  """
4946
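A hedged usage sketch of the ensemble class documented above; the catalog metric names, weights, and main score label are illustrative assumptions, not values taken from the repository:

.. code-block:: python

    from unitxt.metrics import MetricsEnsemble

    # Combine two catalog metrics, weighting the first more heavily.
    ensemble = MetricsEnsemble(
        main_score="ensemble_score",
        metrics=["metrics.accuracy", "metrics.f1_micro"],
        weights=[0.7, 0.3],
    )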
 
operator.py CHANGED
@@ -222,11 +222,11 @@ class SourceOperator(MultiStreamOperator):
222
 
223
  A source operator is responsible for generating the data stream from some source, such as a database or a file.
224
  This is the starting point of a stream processing pipeline.
225
- The `SourceOperator` class is a type of `SourceOperator`, which is a special type of `StreamingOperator`
226
  that generates an output stream but does not take any input streams.
227
 
228
- When called, a `SourceOperator` invokes its `process` method, which should be implemented by all subclasses
229
- to generate the required `MultiStream`.
230
 
231
  """
232
 
@@ -247,9 +247,14 @@ class SourceOperator(MultiStreamOperator):
247
  class StreamInitializerOperator(SourceOperator):
248
  """A class representing a stream initializer operator in the streaming system.
249
 
250
- A stream initializer operator is a special type of `SourceOperator` that is capable of taking parameters during the stream generation process. This can be useful in situations where the stream generation process needs to be customized or configured based on certain parameters.
 
 
 
251
 
252
- When called, a `StreamInitializerOperator` invokes its `process` method, passing any supplied arguments and keyword arguments. The `process` method should be implemented by all subclasses to generate the required `MultiStream` based on the given arguments and keyword arguments.
 
 
253
 
254
  """
255
 
@@ -278,11 +283,12 @@ def instance_result(result_stream):
278
  class StreamOperator(MultiStreamOperator):
279
  """A class representing a single-stream operator in the streaming system.
280
 
281
- A single-stream operator is a type of `MultiStreamOperator` that operates on individual
282
- `Stream` objects within a `MultiStream`. It iterates through each `Stream` in the `MultiStream`
283
- and applies the `process` method.
284
- The `process` method should be implemented by subclasses to define the specific operations
285
- to be performed on each `Stream`.
 
286
 
287
  """
288
 
@@ -353,13 +359,15 @@ class SingleStreamOperator(StreamOperator):
353
  class PagedStreamOperator(StreamOperator):
354
  """A class representing a paged-stream operator in the streaming system.
355
 
356
- A paged-stream operator is a type of `StreamOperator` that operates on a page of instances
357
- in a `Stream` at a time, where a page is a subset of instances.
358
- The `process` method should be implemented by subclasses to define the specific operations
359
  to be performed on each page.
360
 
361
  Args:
362
- page_size (int): The size of each page in the stream. Defaults to 1000.
 
 
363
  """
364
 
365
  page_size: int = 1000
@@ -393,7 +401,12 @@ class PagedStreamOperator(StreamOperator):
393
  class SingleStreamReducer(StreamingOperator):
394
  """A class representing a single-stream reducer in the streaming system.
395
 
396
- A single-stream reducer is a type of `StreamingOperator` that operates on individual `Stream` objects within a `MultiStream` and reduces each `Stream` to a single output value. The `process` method should be implemented by subclasses to define the specific reduction operation to be performed on each `Stream`.
 
 
 
 
 
397
  """
398
 
399
  def __call__(self, multi_stream: Optional[MultiStream] = None) -> Dict[str, Any]:
@@ -412,7 +425,10 @@ class SingleStreamReducer(StreamingOperator):
412
  class InstanceOperator(StreamOperator):
413
  """A class representing a stream instance operator in the streaming system.
414
 
415
- A stream instance operator is a type of `StreamOperator` that operates on individual instances within a `Stream`. It iterates through each instance in the `Stream` and applies the `process` method. The `process` method should be implemented by subclasses to define the specific operations to be performed on each instance.
 
 
 
416
  """
417
 
418
  def _process_stream(
@@ -449,7 +465,8 @@ class InstanceOperator(StreamOperator):
449
  class InstanceOperatorValidator(InstanceOperator):
450
  """A class representing a stream instance operator validator in the streaming system.
451
 
452
- A stream instance operator validator is a type of `InstanceOperator` that includes a validation step. It operates on individual instances within a `Stream` and validates the result of processing each instance.
 
453
  """
454
 
455
  @abstractmethod
 
222
 
223
  A source operator is responsible for generating the data stream from some source, such as a database or a file.
224
  This is the starting point of a stream processing pipeline.
225
+ The ``SourceOperator`` class is a type of ``MultiStreamOperator``, which is a special type of ``StreamingOperator``
226
  that generates an output stream but does not take any input streams.
227
 
228
+ When called, a ``SourceOperator`` invokes its ``process`` method, which should be implemented by all subclasses
229
+ to generate the required ``MultiStream``.
230
 
231
  """
232
 
 
247
  class StreamInitializerOperator(SourceOperator):
248
  """A class representing a stream initializer operator in the streaming system.
249
 
250
+ A stream initializer operator is a special type of ``SourceOperator`` that is capable
251
+ of taking parameters during the stream generation process.
252
+ This can be useful in situations where the stream generation process needs to be
253
+ customized or configured based on certain parameters.
254
 
255
+ When called, a ``StreamInitializerOperator`` invokes its ``process`` method, passing any supplied
256
+ arguments and keyword arguments. The ``process`` method should be implemented by all subclasses
257
+ to generate the required ``MultiStream`` based on the given arguments and keyword arguments.
258
 
259
  """
260
 
 
283
  class StreamOperator(MultiStreamOperator):
284
  """A class representing a single-stream operator in the streaming system.
285
 
286
+ A single-stream operator is a type of ``MultiStreamOperator`` that operates on individual
287
+ ``Stream`` objects within a ``MultiStream``. It iterates through each ``Stream`` in the ``MultiStream``
288
+ and applies the ``process`` method.
289
+
290
+ The ``process`` method should be implemented by subclasses to define the specific operations
291
+ to be performed on each ``Stream``.
292
 
293
  """
294
 
 
359
  class PagedStreamOperator(StreamOperator):
360
  """A class representing a paged-stream operator in the streaming system.
361
 
362
+ A paged-stream operator is a type of ``StreamOperator`` that operates on a page of instances
363
+ in a ``Stream`` at a time, where a page is a subset of instances.
364
+ The ``process`` method should be implemented by subclasses to define the specific operations
365
  to be performed on each page.
366
 
367
  Args:
368
+ page_size (int):
369
+ The size of each page in the stream. Defaults to 1000.
370
+
371
  """
372
 
373
  page_size: int = 1000
 
401
  class SingleStreamReducer(StreamingOperator):
402
  """A class representing a single-stream reducer in the streaming system.
403
 
404
+ A single-stream reducer is a type of ``StreamingOperator`` that operates on individual
405
+ ``Stream`` objects within a ``MultiStream`` and reduces each ``Stream`` to a single output value.
406
+
407
+ The ``process`` method should be implemented by subclasses to define the specific reduction operation
408
+ to be performed on each ``Stream``.
409
+
410
  """
411
 
412
  def __call__(self, multi_stream: Optional[MultiStream] = None) -> Dict[str, Any]:
 
425
  class InstanceOperator(StreamOperator):
426
  """A class representing a stream instance operator in the streaming system.
427
 
428
+ A stream instance operator is a type of ``StreamOperator`` that operates on individual instances
429
+ within a ``Stream``. It iterates through each instance in the ``Stream`` and applies the ``process`` method.
430
+ The ``process`` method should be implemented by subclasses to define the specific operations
431
+ to be performed on each instance.
432
  """
433
 
434
  def _process_stream(
 
465
  class InstanceOperatorValidator(InstanceOperator):
466
  """A class representing a stream instance operator validator in the streaming system.
467
 
468
+ A stream instance operator validator is a type of ``InstanceOperator`` that includes a validation step.
469
+ It operates on individual instances within a ``Stream`` and validates the result of processing each instance.
470
  """
471
 
472
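The rewritten docstrings above all describe the same contract: subclasses implement ``process``. A minimal sketch of that pattern for ``InstanceOperator`` (the operator name and field are hypothetical; the ``process(instance, stream_name)`` signature and module path are assumed from the file shown here):

.. code-block:: python

    from typing import Any, Dict, Optional

    from unitxt.operator import InstanceOperator

    class AddSourceTag(InstanceOperator):  # hypothetical example operator
        tag: str = "demo"

        def process(
            self, instance: Dict[str, Any], stream_name: Optional[str] = None
        ) -> Dict[str, Any]:
            # annotate each instance and pass it along
            instance["source_tag"] = self.tag
            return instance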
  @abstractmethod
operators.py CHANGED
@@ -66,6 +66,7 @@ from .artifact import Artifact, fetch_artifact
66
  from .dataclass import NonPositionalField, OptionalField
67
  from .deprecation_utils import deprecation
68
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
 
69
  from .operator import (
70
  InstanceOperator,
71
  MultiStream,
@@ -81,7 +82,7 @@ from .operator import (
81
  )
82
  from .random_utils import new_random_generator
83
  from .settings_utils import get_settings
84
- from .stream import DynamicStream, ListStream, Stream
85
  from .text_utils import nested_tuple_to_string
86
  from .type_utils import isoftype
87
  from .utils import (
@@ -132,23 +133,24 @@ class IterableSource(SourceOperator):
132
  class MapInstanceValues(InstanceOperator):
133
  """A class used to map instance values into other values.
134
 
135
- This class is a type of InstanceOperator,
136
  it maps values of instances in a stream using predefined mappers.
137
 
138
- Attributes:
139
- mappers (Dict[str, Dict[str, Any]]): The mappers to use for mapping instance values.
140
- Keys are the names of the fields to undergo mapping, and values are dictionaries
141
- that define the mapping from old values to new values.
142
- Note that mapped values are defined by their string representation, so mapped values
143
- are converted to strings before being looked up in the mappers.
144
-
145
- strict (bool): If True, the mapping is applied strictly. That means if a value
146
- does not exist in the mapper, it will raise a KeyError. If False, values
147
- that are not present in the mapper are kept as they are.
148
-
149
- process_every_value (bool): If True, all fields to be mapped should be lists, and the mapping
150
- is to be applied to their individual elements. If False, mapping is only applied to a field
151
- containing a single value.
 
152
 
153
  Examples:
154
  ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
@@ -335,23 +337,23 @@ class InstanceFieldOperator(InstanceOperator):
335
  """A general stream instance operator that processes the values of a field (or multiple ones).
336
 
337
  Args:
338
- field (Optional[str]): The field to process, if only a single one is passed. Defaults to None
339
-
340
- to_field (Optional[str]): Field name to save result into, if only one field is processed, if None is passed the
341
- operation would happen in-place and its result would replace the value of ``field``. Defaults to None
342
-
343
- field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]): Mapping from names of fields to process,
344
- to names of fields to save the results into. Inner List, if used, should be of length 2.
345
- | A field is processed by feeding its value into method ``process_value`` and storing the result in ``to_field`` that
346
- is mapped to the field.
347
- | When the type of argument ``field_to_field`` is List, the order by which the fields are processed is their order
348
- in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
349
- order. The end result might depend on that order if either (1) two different fields are mapped to the same
350
- to_field, or (2) a field shows both as a key and as a value in different mappings.
351
- | The operator throws an AssertionError in either of these cases.
352
- | field_to_field defaults to None
353
-
354
- process_every_value (bool): Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False
355
 
356
  Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field`` ) are equal (or share a common
357
  prefix if ``field`` and ``to_field`` contain a / ), then the result of the operation is saved within ``field`` .
@@ -806,10 +808,16 @@ class TakeByField(InstanceOperator):
806
 
807
 
808
  class Perturb(FieldOperator):
809
- """Slightly perturbs the contents of 'field'. Could be Handy for imitating prediction from given target.
810
 
811
- When task was classification, argument 'select_from' can be used to list the other potential classes, as a
812
  relevant perturbation
 
 
 
 
 
 
813
  """
814
 
815
  select_from: List[Any] = []
@@ -937,12 +945,13 @@ class CastFields(InstanceOperator):
937
  """Casts specified fields to specified types.
938
 
939
  Args:
940
- fields (Dict[str, str]): A dictionary mapping field names to the names of the types to cast the fields to.
941
- e.g: "int", "str", "float", "bool". Basic names of types
942
-
943
- defaults (Dict[str, object]): A dictionary mapping field names to default values for cases of casting failure.
944
-
945
- process_every_value (bool): If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.
 
946
 
947
  Example:
948
  .. code-block:: python
@@ -1268,16 +1277,19 @@ class FilterByExpression(StreamOperator, ComputeExpressionMixin):
1268
  Raises an error if a field participating in the specified condition is missing from the instance
1269
 
1270
  Args:
1271
- expression (str): a condition over fields of the instance, to be processed by python's eval()
1272
- imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json')
1273
- error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
 
 
 
1274
 
1275
  Examples:
1276
- FilterByExpression(expression = "a > 4") will yield only instances where "a">4
1277
- FilterByExpression(expression = "a <= 4 and b > 5") will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5
1278
- FilterByExpression(expression = "a in [4, 8]") will yield only instances where "a" is 4 or 8
1279
- FilterByExpression(expression = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8
1280
- FilterByExpression(expression = "a['b'] not in [4, 8]") will yield only instances where "a" is a dict in which key 'b' is mapped to a value that is neither 4 nor 8
1281
  """
1282
 
1283
  error_on_filtered_all: bool = True
@@ -1635,23 +1647,17 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1635
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1636
  from .metrics import Metric, MetricsList
1637
 
1638
- # Number of instances in input stream is assumed to be small. This is why
1639
- # each metric consumes all of them and lays them in its main memory, and even generates
1640
- # some 1000 copies thereof for the sake of CI.
1641
- # So we start with deep copying here, to make a 'frozen' status of the stream, having
1642
- # passed the preprocess_steps of the task, and inference, and now getting to be evaluated,
1643
- # a frozen status to be fed into each of the metrics listed in metric_field,
1644
- # so that the evaluation of one does not affect the evaluation of another
1645
- # (typically, affecting via change of instance as part of
1646
- # preprocess_steps of MetricPipeline, as illustrated in docs/adding_metrics/Using Metric Pipelines).
1647
-
1648
- instances_upon_entrance_to_metrics_evaluations = []
1649
- for instance in stream:
1650
- instances_upon_entrance_to_metrics_evaluations.append(
1651
- recursive_copy(instance)
1652
- )
1653
 
1654
- first_instance = instances_upon_entrance_to_metrics_evaluations[0]
 
 
 
1655
 
1656
  metric_names = first_instance.get(self.metric_field, [])
1657
  if not metric_names:
@@ -1680,26 +1686,28 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1680
  # by the first listed metric (as desired).
1681
  metrics_list = list(reversed(metrics_list))
1682
 
1683
- for metric in metrics_list:
1684
  if not self.calc_confidence_intervals:
1685
  metric.disable_confidence_interval_calculation()
1686
- multi_stream = MultiStream(
1687
- {
1688
- "tmp": ListStream(
1689
- instances_list=instances_upon_entrance_to_metrics_evaluations,
1690
- copying=True, # ensures deep copy when iterating over instances
1691
- )
1692
- }
1693
- )
1694
- multi_stream = metric(multi_stream)
1695
- for evaluated_instance, freezed_instance in zip(
1696
- multi_stream["tmp"], instances_upon_entrance_to_metrics_evaluations
1697
- ):
1698
- freezed_instance["score"] = recursive_shallow_copy(
1699
- evaluated_instance["score"]
1700
  )
 
 
 
 
 
 
 
 
 
 
1701
 
1702
- yield from instances_upon_entrance_to_metrics_evaluations
1703
 
1704
 
1705
  class MergeStreams(MultiStreamOperator):
@@ -1872,13 +1880,15 @@ class StreamRefiner(StreamOperator):
1872
  input stream. And if the input stream consists of more than 'max_instances' instances, the resulting stream only consists
1873
  of the leading 'max_instances' of the input stream.
1874
 
1875
- Args: max_instances (int)
1876
- apply_to_streams (optional, list(str)): names of streams to refine.
 
 
1877
 
1878
  Examples:
1879
- when input = [{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}] is fed into
1880
- StreamRefiner(max_instances=4)
1881
- the resulting stream is [{"a": 1},{"a": 2},{"a": 3},{"a": 4}]
1882
  """
1883
 
1884
  max_instances: int = None
@@ -1899,18 +1909,20 @@ class DeterministicBalancer(StreamRefiner):
1899
  When also input 'max_instances' is specified, DeterministicBalancer maintains a total instance count not exceeding
1900
  'max_instances'. The total number of discarded instances is as few as possible.
1901
 
1902
- Attributes:
1903
- fields (List[str]): A list of field names to be used in producing the instance's signature.
1904
- max_instances (Optional, int)
 
 
1905
 
1906
  Usage:
1907
- balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)
1908
- balanced_stream = balancer.process(stream)
1909
 
1910
  Example:
1911
- When input [{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}] is fed into
1912
- DeterministicBalancer(fields=["a"])
1913
- the resulting stream will be: [{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]
1914
  """
1915
 
1916
  fields: List[str]
@@ -1947,24 +1959,28 @@ class DeterministicBalancer(StreamRefiner):
1947
  class MinimumOneExamplePerLabelRefiner(StreamRefiner):
1948
  """A class used to return a specified number instances ensuring at least one example per label.
1949
 
1950
- For each instance, a signature value is constructed from the values of the instance in specified input 'fields'.
1951
- MinimumOneExamplePerLabelRefiner takes first instance that appears from each label (each unique signature), and then adds more elements up to the max_instances limit. In general, the refiner takes the first elements in the stream that meet the required conditions.
1952
- MinimumOneExamplePerLabelRefiner then shuffles the results to avoid having one instance
1953
  from each class first and then the rest . If max instance is not set, the original stream will be used
1954
 
1955
- Attributes:
1956
- fields (List[str]): A list of field names to be used in producing the instance's signature.
1957
- max_instances (Optional, int): Number of elements to select. Note that max_instances of StreamRefiners that are passed to the recipe (e.g. 'train_refiner'. `test_refiner`) are overridden by the recipe parameters ( `max_train_instances`, `max_test_instances`)
 
 
 
 
1958
 
1959
  Usage:
1960
- balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)
1961
- balanced_stream = balancer.process(stream)
1962
 
1963
  Example:
1964
- When input [{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}] is fed into
1965
- MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)
1966
  the resulting stream will be:
1967
- [{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}] (order may be different)
1968
  """
1969
 
1970
  fields: List[str]
@@ -2022,20 +2038,19 @@ class LengthBalancer(DeterministicBalancer):
2022
  """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.
2023
 
2024
  Args:
2025
- segments_boundaries (List[int]): distinct integers sorted in increasing order, that maps a given total length
2026
- into the index of the least of them that exceeds the total length. (If none exceeds -- into one index
2027
- beyond, namely, the length of segments_boundaries)
 
 
 
2028
 
2029
- fields (Optional, List[str])
2030
 
2031
  Example:
2032
- when input [{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}] is fed into
2033
-
2034
- .. code-block::
2035
-
2036
- LengthBalancer(fields=["a"], segments_boundaries=[1])
2037
-
2038
- input instances will be counted and balanced against two categories: empty total length (less than 1), and non-empty.
2039
  """
2040
 
2041
  segments_boundaries: List[int]
@@ -2067,9 +2082,11 @@ class UnexpectedHttpCodeError(Exception):
2067
  class DownloadOperator(SideEffectOperator):
2068
  """Operator for downloading a file from a given URL to a specified local path.
2069
 
2070
- Attributes:
2071
- source (str): URL of the file to be downloaded.
2072
- target (str): Local path where the downloaded file should be saved.
 
 
2073
  """
2074
 
2075
  source: str
@@ -2089,9 +2106,11 @@ class DownloadOperator(SideEffectOperator):
2089
  class ExtractZipFile(SideEffectOperator):
2090
  """Operator for extracting files from a zip archive.
2091
 
2092
- Attributes:
2093
- zip_file (str): Path of the zip file to be extracted.
2094
- target_dir (str): Directory where the contents of the zip file will be extracted.
 
 
2095
  """
2096
 
2097
  zip_file: str
@@ -2105,8 +2124,9 @@ class ExtractZipFile(SideEffectOperator):
2105
  class DuplicateInstances(StreamOperator):
2106
  """Operator which duplicates each instance in stream a given number of times.
2107
 
2108
- Attributes:
2109
- num_duplications (int): How many times each instance should be duplicated (1 means no duplication).
 
2110
  duplication_index_field (Optional[str]):
2111
  If given, then additional field with specified name is added to each duplicated instance,
2112
  which contains id of a given duplication. Defaults to None, so no field is added.
 
66
  from .dataclass import NonPositionalField, OptionalField
67
  from .deprecation_utils import deprecation
68
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
69
+ from .generator_utils import ReusableGenerator
70
  from .operator import (
71
  InstanceOperator,
72
  MultiStream,
 
82
  )
83
  from .random_utils import new_random_generator
84
  from .settings_utils import get_settings
85
+ from .stream import DynamicStream, Stream
86
  from .text_utils import nested_tuple_to_string
87
  from .type_utils import isoftype
88
  from .utils import (
 
133
  class MapInstanceValues(InstanceOperator):
134
  """A class used to map instance values into other values.
135
 
136
+ This class is a type of ``InstanceOperator``;
137
  it maps values of instances in a stream using predefined mappers.
138
 
139
+ Args:
140
+ mappers (Dict[str, Dict[str, Any]]):
141
+ The mappers to use for mapping instance values.
142
+ Keys are the names of the fields to undergo mapping, and values are dictionaries
143
+ that define the mapping from old values to new values.
144
+ Note that mapped values are defined by their string representation, so mapped values
145
+ are converted to strings before being looked up in the mappers.
146
+ strict (bool):
147
+ If True, the mapping is applied strictly. That means if a value
148
+ does not exist in the mapper, it will raise a KeyError. If False, values
149
+ that are not present in the mapper are kept as they are.
150
+ process_every_value (bool):
151
+ If True, all fields to be mapped should be lists, and the mapping
152
+ is to be applied to their individual elements.
153
+ If False, mapping is only applied to a field containing a single value.
154
 
155
  Examples:
156
  ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
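A slightly fuller sketch of the arguments documented above, with made-up field names and mappings:

.. code-block:: python

    from unitxt.operators import MapInstanceValues

    # Map each element of the list field "labels", and let unmapped values pass through.
    mapper = MapInstanceValues(
        mappers={"labels": {"0": "negative", "1": "positive"}},
        strict=False,
        process_every_value=True,
    )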
 
337
  """A general stream instance operator that processes the values of a field (or multiple ones).
338
 
339
  Args:
340
+ field (Optional[str]):
341
+ The field to process, if only a single one is passed. Defaults to None
342
+ to_field (Optional[str]):
343
+ Field name to save the result into, if only one field is processed. If None is passed, the
344
+ operation happens in-place and its result replaces the value of ``field``. Defaults to None.
345
+ field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]):
346
+ Mapping from names of fields to process,
347
+ to names of fields to save the results into. Inner List, if used, should be of length 2.
348
+ A field is processed by feeding its value into method ``process_value`` and storing the result in ``to_field`` that
349
+ is mapped to the field. When the type of argument ``field_to_field`` is List, the order by which the fields are processed is their order
350
+ in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
351
+ order. The end result might depend on that order if either (1) two different fields are mapped to the same
352
+ to_field, or (2) a field shows both as a key and as a value in different mappings.
353
+ The operator throws an AssertionError in either of these cases. ``field_to_field``
354
+ defaults to None.
355
+ process_every_value (bool):
356
+ Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False
357
 
358
  Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field`` ) are equal (or share a common
359
  prefix if ``field`` and ``to_field`` contain a / ), then the result of the operation is saved within ``field`` .
 
808
 
809
 
810
  class Perturb(FieldOperator):
811
+ """Slightly perturbs the contents of ``field``. Could be Handy for imitating prediction from given target.
812
 
813
+ When task was classification, argument ``select_from`` can be used to list the other potential classes, as a
814
  relevant perturbation
815
+
816
+ Args:
817
+ percentage_to_perturb (int):
818
+ the percentage of the instances for which to apply this perturbation. Defaults to 1 (1 percent)
819
+ select_from: List[Any]:
820
+ a list of values to select from, as a perturbation of the field's value. Defaults to [].
821
  """
822
 
823
  select_from: List[Any] = []
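A hedged sketch of how the documented arguments might be combined; the field names and label values are assumptions:

.. code-block:: python

    from unitxt.operators import Perturb

    # Copy "target" into "prediction", replacing it for roughly 20% of instances
    # with another value drawn from select_from.
    perturb = Perturb(
        field="target",
        to_field="prediction",
        select_from=["positive", "negative", "neutral"],
        percentage_to_perturb=20,
    )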
 
945
  """Casts specified fields to specified types.
946
 
947
  Args:
948
+ fields (Dict[str, str]):
949
+ A dictionary mapping field names to the names of the types to cast the fields to.
950
+ e.g: "int", "str", "float", "bool". Basic names of types
951
+ defaults (Dict[str, object]):
952
+ A dictionary mapping field names to default values for cases of casting failure.
953
+ process_every_value (bool):
954
+ If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.
955
 
956
  Example:
957
  .. code-block:: python
 
1277
  Raises an error if a field participating in the specified condition is missing from the instance
1278
 
1279
  Args:
1280
+ expression (str):
1281
+ a condition over fields of the instance, to be processed by python's eval()
1282
+ imports_list (List[str]):
1283
+ names of imports needed for the eval of the query (e.g. 're', 'json')
1284
+ error_on_filtered_all (bool, optional):
1285
+ If True, raises an error if all instances are filtered out. Defaults to True.
1286
 
1287
  Examples:
1288
+ | ``FilterByExpression(expression = "a > 4")`` will yield only instances where "a">4
1289
+ | ``FilterByExpression(expression = "a <= 4 and b > 5")`` will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5
1290
+ | ``FilterByExpression(expression = "a in [4, 8]")`` will yield only instances where "a" is 4 or 8
1291
+ | ``FilterByExpression(expression = "a not in [4, 8]")`` will yield only instances where "a" is neither 4 nor 8
1292
+ | ``FilterByExpression(expression = "a['b'] not in [4, 8]")`` will yield only instances where "a" is a dict in which key 'b' is mapped to a value that is neither 4 nor 8
1293
  """
1294
 
1295
  error_on_filtered_all: bool = True
 
1647
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1648
  from .metrics import Metric, MetricsList
1649
 
1650
+ def update_scores_of_stream_instances(
1651
+ stream: Stream, scores: List[dict]
1652
+ ) -> Generator:
1653
+ for instance, score in zip(stream, scores):
1654
+ instance["score"] = recursive_copy(score)
1655
+ yield instance
 
 
 
 
 
 
 
 
 
1656
 
1657
+ # to be populated only when two or more metrics are listed
1658
+ accumulated_scores = []
1659
+
1660
+ first_instance = stream.peek()
1661
 
1662
  metric_names = first_instance.get(self.metric_field, [])
1663
  if not metric_names:
 
1686
  # by the first listed metric (as desired).
1687
  metrics_list = list(reversed(metrics_list))
1688
 
1689
+ for metric_no, metric in enumerate(metrics_list):
1690
  if not self.calc_confidence_intervals:
1691
  metric.disable_confidence_interval_calculation()
1692
+
1693
+ if metric_no > 0:
1694
+ # update input stream with accumulated scores
1695
+ reusable_generator = ReusableGenerator(
1696
+ generator=update_scores_of_stream_instances,
1697
+ gen_kwargs={"stream": stream, "scores": accumulated_scores},
 
 
 
 
 
 
 
 
1698
  )
1699
+ multi_stream = MultiStream.from_generators({"tmp": reusable_generator})
1700
+ else:
1701
+ multi_stream = MultiStream.from_iterables({"tmp": stream})
1702
+ multi_stream = metric(multi_stream)
1703
+ if metric_no < len(metrics_list) - 1:
1704
+ # not the last metric, so prepare for the next metric by
1705
+ # updating accumulated_scores
1706
+ accumulated_scores = []
1707
+ for inst in multi_stream["tmp"]:
1708
+ accumulated_scores.append(recursive_copy(inst["score"]))
1709
 
1710
+ yield from multi_stream["tmp"]
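The loop above feeds each metric a fresh view of the stream in which the previous metric's scores have already been attached. A standalone toy sketch of that score-chaining idea, using plain lists instead of the library's stream classes:

.. code-block:: python

    def attach_scores(instances, scores):
        # re-attach previously accumulated scores before the next metric runs
        for instance, score in zip(instances, scores):
            updated = dict(instance)
            updated["score"] = score
            yield updated

    instances = [{"prediction": "yes"}, {"prediction": "no"}]
    accumulated = [{"instance": {"accuracy": 1.0}}, {"instance": {"accuracy": 0.0}}]
    for inst in attach_scores(instances, accumulated):
        print(inst["score"]["instance"]["accuracy"])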
1711
 
1712
 
1713
  class MergeStreams(MultiStreamOperator):
 
1880
  input stream. And if the input stream consists of more than 'max_instances' instances, the resulting stream only consists
1881
  of the leading 'max_instances' of the input stream.
1882
 
1883
+ Args:
1884
+ max_instances (int):
+ the maximum number of instances to keep in the resulting stream.
1885
+ apply_to_streams (optional, list(str)):
1886
+ names of streams to refine.
1887
 
1888
  Examples:
1889
+ when input = ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}]`` is fed into
1890
+ ``StreamRefiner(max_instances=4)``
1891
+ the resulting stream is ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4}]``
1892
  """
1893
 
1894
  max_instances: int = None
 
1909
  When also input 'max_instances' is specified, DeterministicBalancer maintains a total instance count not exceeding
1910
  'max_instances'. The total number of discarded instances is as few as possible.
1911
 
1912
+ Args:
1913
+ fields (List[str]):
1914
+ A list of field names to be used in producing the instance's signature.
1915
+ max_instances (Optional, int):
1916
+ an upper bound on the total number of instances in the resulting stream.
1917
 
1918
  Usage:
1919
+ ``balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)``
1920
+ ``balanced_stream = balancer.process(stream)``
1921
 
1922
  Example:
1923
+ When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}]`` is fed into
1924
+ ``DeterministicBalancer(fields=["a"])``
1925
+ the resulting stream will be: ``[{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]``
1926
  """
1927
 
1928
  fields: List[str]
 
1959
  class MinimumOneExamplePerLabelRefiner(StreamRefiner):
1960
  """A class used to return a specified number instances ensuring at least one example per label.
1961
 
1962
+ For each instance, a signature value is constructed from the values of the instance in specified input ``fields``.
1963
+ ``MinimumOneExamplePerLabelRefiner`` takes first instance that appears from each label (each unique signature), and then adds more elements up to the max_instances limit. In general, the refiner takes the first elements in the stream that meet the required conditions.
1964
+ ``MinimumOneExamplePerLabelRefiner`` then shuffles the results to avoid having one instance
1965
  from each class first and then the rest . If max instance is not set, the original stream will be used
1966
 
1967
+ Args:
1968
+ fields (List[str]):
1969
+ A list of field names to be used in producing the instance's signature.
1970
+ max_instances (Optional, int):
1971
+ Number of elements to select. Note that max_instances of StreamRefiners
1972
+ that are passed to the recipe (e.g. ``train_refiner``. ``test_refiner``) are overridden
1973
+ by the recipe parameters ( ``max_train_instances``, ``max_test_instances``)
1974
 
1975
  Usage:
1976
+ | ``balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)``
1977
+ | ``balanced_stream = balancer.process(stream)``
1978
 
1979
  Example:
1980
+ When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}]`` is fed into
1981
+ ``MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)``
1982
  the resulting stream will be:
1983
+ ``[{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}]`` (order may be different)
1984
  """
1985
 
1986
  fields: List[str]
 
2038
  """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.
2039
 
2040
  Args:
2041
+ segments_boundaries (List[int]):
2042
+ distinct integers sorted in increasing order, that map a given total length
2043
+ into the index of the least of them that exceeds the given total length.
2044
+ (If none exceeds -- into one index beyond, namely, the length of segments_boundaries)
2045
+ fields (Optional, List[str]):
2046
+ the total length of the values of these fields goes through the quantization described above
2047
 
 
2048
 
2049
  Example:
2050
+ when input ``[{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}]``
2051
+ is fed into ``LengthBalancer(fields=["a"], segments_boundaries=[1])``,
2052
+ input instances will be counted and balanced against two categories:
2053
+ empty total length (less than 1), and non-empty.
 
 
 
2054
  """
2055
 
2056
  segments_boundaries: List[int]
 
2082
  class DownloadOperator(SideEffectOperator):
2083
  """Operator for downloading a file from a given URL to a specified local path.
2084
 
2085
+ Args:
2086
+ source (str):
2087
+ URL of the file to be downloaded.
2088
+ target (str):
2089
+ Local path where the downloaded file should be saved.
2090
  """
2091
 
2092
  source: str
 
2106
  class ExtractZipFile(SideEffectOperator):
2107
  """Operator for extracting files from a zip archive.
2108
 
2109
+ Args:
2110
+ zip_file (str):
2111
+ Path of the zip file to be extracted.
2112
+ target_dir (str):
2113
+ Directory where the contents of the zip file will be extracted.
2114
  """
2115
 
2116
  zip_file: str
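A hedged how-to sketch chaining the two side-effect operators documented above; the URL and local paths are placeholders, and wiring the operators into a pipeline is omitted:

.. code-block:: python

    from unitxt.operators import DownloadOperator, ExtractZipFile

    download = DownloadOperator(
        source="https://example.com/archive.zip",
        target="/tmp/archive.zip",
    )
    extract = ExtractZipFile(
        zip_file="/tmp/archive.zip",
        target_dir="/tmp/archive",
    )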
 
2124
  class DuplicateInstances(StreamOperator):
2125
  """Operator which duplicates each instance in stream a given number of times.
2126
 
2127
+ Args:
2128
+ num_duplications (int):
2129
+ How many times each instance should be duplicated (1 means no duplication).
2130
  duplication_index_field (Optional[str]):
2131
  If given, then additional field with specified name is added to each duplicated instance,
2132
  which contains id of a given duplication. Defaults to None, so no field is added.
processors.py CHANGED
@@ -132,6 +132,14 @@ class TakeFirstNonEmptyLine(FieldOperator):
132
  return parts[0].strip()
133
 
134
 
 
 
 
 
 
 
 
 
135
  class ConvertToBoolean(FieldOperator):
136
  def process_value(self, text: Any) -> Any:
137
  clean_instance = str(text).strip().lower()
@@ -157,6 +165,11 @@ class Lower(FieldOperator):
157
  return text.lower()
158
 
159
 
 
 
 
 
 
160
  @deprecation("2.0.0", alternative=Lower)
161
  class LowerCase(Lower):
162
  pass
 
132
  return parts[0].strip()
133
 
134
 
135
+ class TakeLastNonEmptyLine(FieldOperator):
136
+ def process_value(self, text: Any) -> Any:
137
+ parts = str(text).strip().split("\n")
138
+ if len(parts) == 0:
139
+ return ""
140
+ return parts[-1].strip()
141
+
142
+
143
  class ConvertToBoolean(FieldOperator):
144
  def process_value(self, text: Any) -> Any:
145
  clean_instance = str(text).strip().lower()
 
165
  return text.lower()
166
 
167
 
168
+ class Upper(FieldOperator):
169
+ def process_value(self, text: Any) -> Any:
170
+ return str(text).upper()
171
+
172
+
173
  @deprecation("2.0.0", alternative=Lower)
174
  class LowerCase(Lower):
175
  pass
schema.py CHANGED
@@ -143,6 +143,9 @@ class FinalizeDataset(InstanceOperatorValidator):
143
  )
144
 
145
  task_data["metadata"]["num_demos"] = instance["recipe_metadata"]["num_demos"]
 
 
 
146
  task_data["metadata"]["template"] = self.artifact_to_jsonable(
147
  instance["recipe_metadata"]["template"]
148
  )
 
143
  )
144
 
145
  task_data["metadata"]["num_demos"] = instance["recipe_metadata"]["num_demos"]
146
+ task_data["metadata"]["demos_pool_size"] = instance["recipe_metadata"][
147
+ "demos_pool_size"
148
+ ]
149
  task_data["metadata"]["template"] = self.artifact_to_jsonable(
150
  instance["recipe_metadata"]["template"]
151
  )
settings_utils.py CHANGED
@@ -138,7 +138,7 @@ if Settings.is_uninitilized():
138
  settings.max_log_message_size = (int, 100000)
139
  settings.catalogs = None
140
  settings.artifactories = None
141
- settings.default_recipe = "standard_recipe"
142
  settings.default_verbosity = "info"
143
  settings.use_eager_execution = False
144
  settings.remote_metrics = []
@@ -186,6 +186,7 @@ if Constants.is_uninitilized():
186
  constants.inference_stream = "__INFERENCE_STREAM__"
187
  constants.instance_stream = "__INSTANCE_STREAM__"
188
  constants.image_tag = "unitxt-img"
 
189
 
190
 
191
  def get_settings() -> Settings:
 
138
  settings.max_log_message_size = (int, 100000)
139
  settings.catalogs = None
140
  settings.artifactories = None
141
+ settings.default_recipe = "dataset_recipe"
142
  settings.default_verbosity = "info"
143
  settings.use_eager_execution = False
144
  settings.remote_metrics = []
 
186
  constants.inference_stream = "__INFERENCE_STREAM__"
187
  constants.instance_stream = "__INSTANCE_STREAM__"
188
  constants.image_tag = "unitxt-img"
189
+ constants.demos_pool_field = "_demos_pool_"
190
 
191
 
192
  def get_settings() -> Settings:
span_lableing_operators.py CHANGED
@@ -6,19 +6,18 @@ from .operator import InstanceOperator
6
  class IobExtractor(InstanceOperator):
7
  """A class designed to extract entities from sequences of text using the Inside-Outside-Beginning (IOB) tagging convention. It identifies entities based on IOB tags and categorizes them into predefined labels such as Person, Organization, and Location.
8
 
9
- Attributes:
10
- labels (List[str]): A list of entity type labels, e.g., ["Person", "Organization", "Location"].
11
-
12
- begin_labels (List[str]): A list of labels indicating the beginning of an entity, e.g., ["B-PER", "B-ORG", "B-LOC"].
13
-
14
- inside_labels (List[str]): A list of labels indicating the continuation of an entity, e.g., ["I-PER", "I-ORG", "I-LOC"].
15
-
16
- outside_label (str): The label indicating tokens outside of any entity, typically "O".
 
17
 
18
  The extraction process identifies spans of text corresponding to entities and labels them according to their entity type. Each span is annotated with a start and end character offset, the entity text, and the corresponding label.
19
 
20
-
21
-
22
  Example of instantiation and usage:
23
 
24
  .. code-block:: python
 
6
  class IobExtractor(InstanceOperator):
7
  """A class designed to extract entities from sequences of text using the Inside-Outside-Beginning (IOB) tagging convention. It identifies entities based on IOB tags and categorizes them into predefined labels such as Person, Organization, and Location.
8
 
9
+ Args:
10
+ labels (List[str]):
11
+ A list of entity type labels, e.g., ["Person", "Organization", "Location"].
12
+ begin_labels (List[str]):
13
+ A list of labels indicating the beginning of an entity, e.g., ["B-PER", "B-ORG", "B-LOC"].
14
+ inside_labels (List[str]):
15
+ A list of labels indicating the continuation of an entity, e.g., ["I-PER", "I-ORG", "I-LOC"].
16
+ outside_label (str):
17
+ The label indicating tokens outside of any entity, typically "O".
18
 
19
  The extraction process identifies spans of text corresponding to entities and labels them according to their entity type. Each span is annotated with a start and end character offset, the entity text, and the corresponding label.
20
 
 
 
21
  Example of instantiation and usage:
22
 
23
  .. code-block:: python
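The example block itself is not captured in this view; based only on the arguments documented above, an instantiation sketch might look like:

.. code-block:: python

    from unitxt.span_lableing_operators import IobExtractor

    iob_extractor = IobExtractor(
        labels=["Person", "Organization", "Location"],
        begin_labels=["B-PER", "B-ORG", "B-LOC"],
        inside_labels=["I-PER", "I-ORG", "I-LOC"],
        outside_label="O",
    )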
splitters.py CHANGED
@@ -1,11 +1,11 @@
1
  import itertools
2
  from abc import abstractmethod
3
  from difflib import get_close_matches
4
- from typing import Dict, List, Optional
5
 
6
  from .artifact import Artifact
7
  from .dict_utils import dict_get
8
- from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
9
  from .random_utils import new_random_generator
10
  from .split_utils import (
11
  parse_random_mix_string,
@@ -14,7 +14,7 @@ from .split_utils import (
14
  rename_split,
15
  slice_streams,
16
  )
17
- from .stream import EmptyStreamError, FaultyStreamError, MultiStream
18
  from .type_utils import isoftype
19
  from .utils import recursive_copy
20
 
@@ -118,14 +118,14 @@ class Sampler(Artifact):
118
  def sample(
119
  self,
120
  sample_size: int,
121
- instances_pool: List[Dict[str, object]],
122
- instance: Dict[str, object],
123
- ) -> List[Dict[str, object]]:
124
  pass
125
 
126
  def filter_source_by_instance(
127
- self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
128
- ) -> List[Dict[str, object]]:
129
  if "input_fields" not in instance:
130
  raise ValueError(f"'input_fields' field is missing from '{instance}'.")
131
  try:
@@ -336,10 +336,11 @@ class DiverseLabelsSampler(Sampler):
336
  return result
337
 
338
 
339
- class Sample(InstanceOperatorWithMultiStreamAccess):
340
- from_stream: str
341
  to_field: str
342
  sampler: Sampler
 
343
 
344
  def prepare(self):
345
  self.local_cache = None
@@ -350,40 +351,36 @@ class Sample(InstanceOperatorWithMultiStreamAccess):
350
  pass
351
 
352
  def process(
353
- self, instance: Dict[str, object], multi_stream: MultiStream
354
- ) -> Dict[str, object]:
355
- sample_size = self.get_sample_size(instance)
356
- try:
357
- if self.local_cache is None:
358
- self.local_cache = recursive_copy(list(multi_stream[self.from_stream]))
359
 
360
- source_stream = self.local_cache
361
- source_stream = self.sampler.filter_source_by_instance(
362
- source_stream, instance
363
- )
364
- if len(source_stream) < sample_size:
365
- raise ValueError(
366
- f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {self.sampler.sample_size}."
367
- )
368
- sampled_instances = self.sampler.sample(
369
- sample_size=sample_size, instances_pool=source_stream, instance=instance
370
  )
371
- instance[self.to_field] = sampled_instances
372
- return instance
373
- except FaultyStreamError as e:
374
- raise EmptyStreamError(
375
- f"Unable to fetch instances from '{self.from_stream}' to '{self.to_field}', due to {e.__class__.__name__}: {e}"
376
- ) from e
377
 
378
 
379
- class ConstantSizeSample(Sample):
380
  sample_size: int
381
 
382
  def get_sample_size(self, instance) -> int:
383
  return self.sample_size
384
 
385
 
386
- class RandomSizeSample(Sample):
387
  sample_sizes: List[int]
388
 
389
  def get_sample_size(self, instance) -> int:
 
1
  import itertools
2
  from abc import abstractmethod
3
  from difflib import get_close_matches
4
+ from typing import Any, Dict, List, Optional
5
 
6
  from .artifact import Artifact
7
  from .dict_utils import dict_get
8
+ from .operator import InstanceOperator, MultiStreamOperator
9
  from .random_utils import new_random_generator
10
  from .split_utils import (
11
  parse_random_mix_string,
 
14
  rename_split,
15
  slice_streams,
16
  )
17
+ from .stream import MultiStream
18
  from .type_utils import isoftype
19
  from .utils import recursive_copy
20
 
 
118
  def sample(
119
  self,
120
  sample_size: int,
121
+ instances_pool: List[Dict[str, Any]],
122
+ instance: Dict[str, Any],
123
+ ) -> List[Dict[str, Any]]:
124
  pass
125
 
126
  def filter_source_by_instance(
127
+ self, instances_pool: List[Dict[str, Any]], instance: Dict[str, Any]
128
+ ) -> List[Dict[str, Any]]:
129
  if "input_fields" not in instance:
130
  raise ValueError(f"'input_fields' field is missing from '{instance}'.")
131
  try:
 
336
  return result
337
 
338
 
339
+ class AssignDemosToInstance(InstanceOperator):
340
+ from_field: str
341
  to_field: str
342
  sampler: Sampler
343
+ skip_demoed_instances: bool = False
344
 
345
  def prepare(self):
346
  self.local_cache = None
 
351
  pass
352
 
353
  def process(
354
+ self, instance: Dict[str, Any], multi_stream: MultiStream
355
+ ) -> Dict[str, Any]:
356
+ if self.skip_demoed_instances and self.to_field in instance:
357
+ if self.from_field in instance:
358
+ instance.pop(self.from_field)
359
+ return instance
360
 
361
+ demos_pool = instance[self.from_field]
362
+ sample_size = self.get_sample_size(instance)
363
+ source_stream = self.sampler.filter_source_by_instance(demos_pool, instance)
364
+ if len(source_stream) < sample_size:
365
+ raise ValueError(
366
+ f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}. Please consider increasing increasing the demos pool, for which you may need to increase loader_limit or employ a less strict stream filtering."
 
 
 
 
367
  )
368
+ sampled_instances = self.sampler.sample(
369
+ sample_size=sample_size, instances_pool=source_stream, instance=instance
370
+ )
371
+ instance[self.to_field] = recursive_copy(sampled_instances)
372
+ instance.pop(self.from_field) # pop the field pointing to the demos_pool
373
+ return instance
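Illustrative only: the shape an instance might have before and after ``AssignDemosToInstance``, assuming ``from_field="_demos_pool_"``, ``to_field="demos"`` and a sample size of 1 (field contents are made up):

.. code-block:: python

    before = {
        "input_fields": {"question": "How do I fix a flat tire?"},
        "_demos_pool_": [
            {"input_fields": {"question": "How do I replace a chain?"}},
            {"input_fields": {"question": "How do I adjust my brakes?"}},
        ],
    }
    after = {
        "input_fields": {"question": "How do I fix a flat tire?"},
        # one demo sampled from the pool; the pool field itself has been popped
        "demos": [{"input_fields": {"question": "How do I replace a chain?"}}],
    }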
374
 
375
 
376
+ class ConstantSizeSample(AssignDemosToInstance):
377
  sample_size: int
378
 
379
  def get_sample_size(self, instance) -> int:
380
  return self.sample_size
381
 
382
 
383
+ class RandomSizeSample(AssignDemosToInstance):
384
  sample_sizes: List[int]
385
 
386
  def get_sample_size(self, instance) -> int:
standard.py CHANGED
@@ -1,26 +1,35 @@
1
- from typing import List, Optional, Union
 
 
 
2
 
3
  from .artifact import fetch_artifact
4
  from .augmentors import Augmentor, NullAugmentor
5
  from .card import TaskCard
6
  from .collections_operators import GetLength
7
  from .dataclass import Field, InternalField, NonPositionalField, OptionalField
 
8
  from .error_utils import UnitxtError
9
  from .formats import Format, SystemFormat
 
10
  from .logging_utils import get_logger
11
- from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
 
 
 
 
 
12
  from .operators import Set, StreamRefiner
13
- from .recipe import Recipe
14
  from .schema import FinalizeDataset
15
  from .serializers import SingleTypeSerializer
16
  from .settings_utils import get_constants, get_settings
17
- from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
18
  from .stream import MultiStream
19
  from .system_prompts import EmptySystemPrompt, SystemPrompt
20
  from .task import Task
21
  from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template, TemplatesList
22
  from .type_utils import isoftype
23
- from .utils import LRUCache
24
 
25
  constants = get_constants()
26
  settings = get_settings()
@@ -28,11 +37,205 @@ logger = get_logger()
28
 
29
 
30
  # Used to give meaningful name to recipe steps
31
- class CreateDemosPool(SeparateSplit):
32
- pass
 
 
 
 
33
 
 
 
 
34
 
35
- class BaseRecipe(Recipe, SourceSequentialOperator):
36
  # Base parameters
37
  card: TaskCard = None
38
  task: Task = None
@@ -59,14 +262,18 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
59
  test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
60
 
61
  demos_pool_size: int = None
 
62
  num_demos: Optional[Union[int, List[int]]] = 0
63
  demos_removed_from_data: bool = True
 
64
 
65
- demos_pool_name: str = "demos_pool"
66
  demos_taken_from: str = "train"
67
  demos_field: str = "demos"
68
  sampler: Sampler = None
69
 
 
 
 
70
  augmentor: Union[Augmentor, List[Augmentor]] = OptionalField(default=None)
71
 
72
  steps: List[StreamingOperator] = InternalField(default_factory=list)
@@ -101,11 +308,16 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
101
  raise ValueError(
102
  "When using demonstrations both num_demos and demos_pool_size should be assigned with positive integers."
103
  )
104
- if self.demos_pool_size < self.max_demos_size:
105
  raise ValueError(
106
- f"num_demos (got: {self.max_demos_size}) should not exceed demos_pool_size (got: {self.demos_pool_size})"
107
  )
108
- if self.loader_limit and self.demos_pool_size > self.loader_limit:
 
 
 
 
 
109
  raise ValueError(
110
  f"demos_pool_size should not exceed loader_limit ({self.loader_limit}), Got demos_pool_size={self.demos_pool_size}"
111
  )
@@ -220,29 +432,21 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
220
  self.loading,
221
  self.metadata,
222
  self.standardization,
223
- self.processing,
224
  ]
225
 
226
  self.inference = SequentialOperator()
227
 
228
- self.inference.steps = [self.metadata, self.verbalization, self.finalize]
229
 
230
  def production_preprocess(self, task_instances):
231
  ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
232
- return list(self.inference_instance(ms)[constants.inference_stream])
233
-
234
- def production_demos_pool(self):
235
- if self.use_demos:
236
- demos_pool = self.__class__._demos_pool_cache.get(str(self), None)
237
- if demos_pool is None:
238
- demos_pool = list(self.inference_demos()[self.demos_pool_name])
239
- self.__class__._demos_pool_cache[str(self)] = demos_pool
240
- return demos_pool
241
- return []
242
 
243
  @property
244
  def has_custom_demos_pool(self):
245
- return self.demos_pool_size is not None and self.demos_pool_size > 0
 
 
246
 
247
  @property
248
  def use_demos(self):
@@ -251,13 +455,22 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
251
  def produce(self, task_instances):
252
  """Use the recipe in production to produce model ready query from standard task instance."""
253
  self.before_process_multi_stream()
254
- streams = {
255
- constants.inference_stream: self.production_preprocess(task_instances),
256
- }
257
- if self.use_demos:
258
- streams[self.demos_pool_name] = self.production_demos_pool()
259
- multi_stream = MultiStream.from_iterables(streams)
260
- multi_stream = self.inference(multi_stream)
 
 
 
 
 
 
 
 
 
261
  return list(multi_stream[constants.inference_stream])
262
 
263
  def reset(self):
@@ -321,15 +534,29 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
321
  augmentor.set_fields(self.card.task.augmentable_inputs)
322
  self.processing.steps.append(augmentor)
323
 
 
 
 
 
324
  if self.has_custom_demos_pool:
325
- self.processing.steps.append(
326
- CreateDemosPool(
327
- from_split=self.demos_taken_from,
328
- to_split_names=[self.demos_pool_name, self.demos_taken_from],
329
- to_split_sizes=[int(self.demos_pool_size)],
330
- remove_targets_from_source_split=self.demos_removed_from_data,
 
 
 
 
 
 
 
 
 
 
 
331
  )
332
- )
333
 
334
  if self.use_demos:
335
  if self.sampler is None:
@@ -346,28 +573,41 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
346
  if isinstance(self.num_demos, int):
347
  self.verbalization.steps.append(
348
  ConstantSizeSample(
349
- from_stream=self.demos_pool_name,
350
  to_field=self.demos_field,
351
  sampler=self.sampler,
352
  sample_size=self.num_demos,
 
353
  )
354
  )
355
  self.verbalization.steps.append(
356
- Set(fields={"recipe_metadata/num_demos": self.num_demos})
 
 
 
 
 
357
  )
358
 
359
  elif isinstance(self.num_demos, list):
360
  self.verbalization.steps.append(
361
  RandomSizeSample(
362
- from_stream=self.demos_pool_name,
363
  to_field=self.demos_field,
364
  sampler=self.sampler,
365
  sample_sizes=self.num_demos,
 
366
  )
367
  )
368
  self.verbalization.steps.append(
369
  GetLength(field="demos", to_field="recipe_metadata/num_demos")
370
  )
 
 
 
 
 
 
371
  else:
372
  raise ValueError("num_demos must be int or List[int]")
373
 
@@ -383,9 +623,15 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
383
  template=self.template, demos_field=self.demos_field
384
  )
385
  )
 
386
  else:
387
  self.verbalization.steps.append(
388
- Set(fields={"recipe_metadata/num_demos": 0})
 
 
 
 
 
389
  )
390
  if isinstance(self.template, list):
391
  self.verbalization.steps.append(
@@ -409,15 +655,6 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
409
 
410
  self.finalize.steps.append(FinalizeDataset(group_by=self.group_by))
411
 
412
- def prepare(self):
413
- if isinstance(self.template, TemplatesList):
414
- self.template = self.template.items
415
- self.reset_pipeline()
416
-
417
-
418
- class StandardRecipeWithIndexes(BaseRecipe):
419
- template_card_index: int = None
420
-
421
  def prepare(self):
422
  assert (
423
  self.template_card_index is None or self.template is None
@@ -464,77 +701,41 @@ class StandardRecipeWithIndexes(BaseRecipe):
464
  raise ValueError(
465
  "No template was specified in the the 'template' or 'template_card_index' recipe arguments, and no default templates are defined the card or task"
466
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
467
 
468
- super().prepare()
 
 
 
 
469
 
 
 
 
470
 
471
- class StandardRecipe(StandardRecipeWithIndexes):
472
- """This class represents a standard recipe for data processing and preparation.
473
 
474
- This class can be used to prepare a recipe.
475
- with all necessary steps, refiners and renderers included. It allows to set various
476
- parameters and steps in a sequential manner for preparing the recipe.
477
 
478
- Args:
479
- card (TaskCard):
480
- TaskCard object associated with the recipe.
481
- template (Template, optional):
482
- Template object to be used for the recipe.
483
- system_prompt (SystemPrompt, optional):
484
- SystemPrompt object to be used for the recipe.
485
- loader_limit (int, optional):
486
- Specifies the maximum number of instances per stream to be returned from the loader (used to reduce loading time in large datasets)
487
- format (SystemFormat, optional):
488
- SystemFormat object to be used for the recipe.
489
- metrics (List[str]):
490
- list of catalog metrics to use with this recipe.
491
- postprocessors (List[str]):
492
- list of catalog processors to apply at post processing. (Not recommended to use from here)
493
- group_by (List[Union[str, List[str]]]):
494
- list of task_data or metadata keys to group global scores by.
495
- train_refiner (StreamRefiner, optional):
496
- Train refiner to be used in the recipe.
497
- max_train_instances (int, optional):
498
- Maximum training instances for the refiner.
499
- validation_refiner (StreamRefiner, optional):
500
- Validation refiner to be used in the recipe.
501
- max_validation_instances (int, optional):
502
- Maximum validation instances for the refiner.
503
- test_refiner (StreamRefiner, optional):
504
- Test refiner to be used in the recipe.
505
- max_test_instances (int, optional):
506
- Maximum test instances for the refiner.
507
- demos_pool_size (int, optional):
508
- Size of the demos pool.
509
- num_demos (int, optional):
510
- Number of demos to be used.
511
- demos_pool_name (str, optional):
512
- Name of the demos pool. Default is "demos_pool".
513
- demos_taken_from (str, optional):
514
- Specifies from where the demos are taken. Default is "train".
515
- demos_field (str, optional):
516
- Field name for demos. Default is "demos".
517
- demos_removed_from_data (bool, optional):
518
- whether to remove the demos from the source data, Default is True
519
- sampler (Sampler, optional):
520
- The Sampler used to select the demonstrations when num_demos > 0.
521
- steps (List[StreamingOperator], optional):
522
- List of StreamingOperator objects to be used in the recipe.
523
- augmentor (Augmentor) :
524
- Augmentor to be used to pseudo randomly augment the source text
525
- instruction_card_index (int, optional):
526
- Index of instruction card to be used for preparing the recipe.
527
- template_card_index (int, optional):
528
- Index of template card to be used for preparing the recipe.
529
 
530
- Methods:
531
- prepare():
532
- This overridden method is used for preparing the recipe
533
- by arranging all the steps, refiners, and renderers in a sequential manner.
534
 
535
- Raises:
536
- AssertionError:
537
- If both template and template_card_index are specified at the same time.
538
- """
539
 
 
 
540
  pass
 
1
+ import itertools
2
+ import json
3
+ import sys
4
+ from typing import Any, Dict, Generator, List, Optional, Union
5
 
6
  from .artifact import fetch_artifact
7
  from .augmentors import Augmentor, NullAugmentor
8
  from .card import TaskCard
9
  from .collections_operators import GetLength
10
  from .dataclass import Field, InternalField, NonPositionalField, OptionalField
11
+ from .deprecation_utils import deprecation
12
  from .error_utils import UnitxtError
13
  from .formats import Format, SystemFormat
14
+ from .generator_utils import ReusableGenerator
15
  from .logging_utils import get_logger
16
+ from .operator import (
17
+ MultiStreamOperator,
18
+ SequentialOperator,
19
+ SourceSequentialOperator,
20
+ StreamingOperator,
21
+ )
22
  from .operators import Set, StreamRefiner
 
23
  from .schema import FinalizeDataset
24
  from .serializers import SingleTypeSerializer
25
  from .settings_utils import get_constants, get_settings
26
+ from .splitters import ConstantSizeSample, RandomSizeSample, Sampler
27
  from .stream import MultiStream
28
  from .system_prompts import EmptySystemPrompt, SystemPrompt
29
  from .task import Task
30
  from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template, TemplatesList
31
  from .type_utils import isoftype
32
+ from .utils import LRUCache, recursive_copy
33
 
34
  constants = get_constants()
35
  settings = get_settings()
 
37
 
38
 
39
  # Used to give meaningful name to recipe steps
40
+ class CreateDemosPool(MultiStreamOperator):
41
+ from_stream: str = None
42
+ demos_pool_size: int = None
43
+ demos_removed_from_data: bool = None
44
+ to_field: str = constants.demos_pool_field
45
+
46
+ # flake8: noqa: B007
47
+ def process(self, multi_stream: MultiStream) -> MultiStream:
48
+ # generate the demos_pool as a selection of demos_pool_size distinct instances
49
+ # (distinct by their "input_fields" field). The selection is taken from stream named from_stream.
50
+ # The selected instances are later treated as ordinary instances or not, depending on parameter
51
+ # demos_removed_from_data.
52
+ # The selection of instances is done from the first instances of the stream named from_stream.
53
+ # Instances that are not distinct from previously selected demo instances are kept aside, to be later
54
+ # treated like all the remaining instances of stream from_stream.
55
+ if self.from_stream not in multi_stream:
56
+ raise ValueError(
57
+ f"Input multi-stream is missing a stream named '{self.from_stream}' to take demo instances from for the demos_pool."
58
+ )
59
+ if (
60
+ self.demos_removed_from_data is not None
61
+ and self.demos_removed_from_data is True
62
+ and (self.demos_pool_size == sys.maxsize)
63
+ ):
64
+ # going to consume the whole of the input stream named self.from_stream for demo instances,
65
+ # and not let demo instances behave as regular instances; so self.from_stream
66
+ # ends its life here as an input stream that is expected to reach the end of the recipe
67
+ if len(multi_stream) == 1:
68
+ raise ValueError(
69
+ f"The single input stream, '{self.from_stream}' is to be wholly consumed for generating demos, and no instance is left to use these demos."
70
+ )
71
+ from_stream = multi_stream[self.from_stream]
72
+ demos_pool = []
73
+ input_fields_of_demos_pool = []
74
+ not_selected_from_from_stream = []
75
+ for num_scanned, instance in enumerate(from_stream):
76
+ if "input_fields" not in instance:
77
+ raise ValueError(f"'input_fields' field is missing from '{instance}'.")
78
+ input_fields_signature = json.dumps(
79
+ instance["input_fields"], sort_keys=True
80
+ )
81
+ if input_fields_signature in input_fields_of_demos_pool:
82
+ not_selected_from_from_stream.append(instance)
83
+ continue
84
+ demos_pool.append(instance)
85
+ input_fields_of_demos_pool.append(input_fields_signature)
86
+ if len(demos_pool) >= self.demos_pool_size:
87
+ break
88
+
89
+ # for backward compatibility, do not throw exception here if demos pool is smaller than expected.
90
+ # Delay that for the event (if occurs) that Sample is not be able to sample num_demos demos.
91
+
92
+ # to avoid endless recursion in case of not demos_removed_from_data
93
+ demos_pool = recursive_copy(demos_pool)
94
+
95
+ set_demos_pool = Set(fields={self.to_field: demos_pool})
96
+ if (
97
+ self.demos_removed_from_data is not None
98
+ and self.demos_removed_from_data is False
99
+ ):
100
+ # all input instances go out. No one is "killed" because selected as demo
101
+ return set_demos_pool(multi_stream)
102
+
103
+ if (
104
+ self.demos_removed_from_data is not None
105
+ and self.demos_removed_from_data is True
106
+ ):
107
+ if self.demos_pool_size == sys.maxsize:
108
+ # consume the whole of input stream self.from_stream, just for demos, and do not
109
+ # take any of its instances to behave as a non-demo instance, i.e., a regular instance
110
+ # that consume the demos
111
+ out_ms = MultiStream(
112
+ {
113
+ stream_name: multi_stream[stream_name]
114
+ for stream_name in multi_stream
115
+ if stream_name != self.from_stream
116
+ }
117
+ )
118
+ return set_demos_pool(out_ms)
119
+
120
+ # self.demos_removed_from_data and not consume the whole of self.from_stream just for demos
121
+ def from_stream_generator(
122
+ first_layer: list, ms: MultiStream, stream_name: str, start: int
123
+ ) -> Generator:
124
+ yield from first_layer
125
+ yield from itertools.islice(ms[stream_name], start, None)
126
+
127
+ new_streams = {}
128
+ for stream_name in multi_stream:
129
+ if stream_name == self.from_stream:
130
+ new_streams[stream_name] = ReusableGenerator(
131
+ generator=from_stream_generator,
132
+ gen_kwargs={
133
+ "first_layer": not_selected_from_from_stream,
134
+ "ms": multi_stream,
135
+ "stream_name": self.from_stream,
136
+ "start": num_scanned + 1,
137
+ },
138
+ )
139
+ else:
140
+ new_streams[stream_name] = ReusableGenerator(
141
+ generator=from_stream_generator,
142
+ gen_kwargs={
143
+ "first_layer": [],
144
+ "ms": multi_stream,
145
+ "stream_name": stream_name,
146
+ "start": 0,
147
+ },
148
+ )
149
+
150
+ ms = MultiStream.from_generators(new_streams)
151
+ return set_demos_pool(ms)
152
+
153
+
154
+ class AddDemosPool(MultiStreamOperator):
155
+ demos_pool: List[Dict[str, Any]]
156
+ demos_pool_field_name: str = constants.demos_pool_field
157
+
158
+ def process(self, multi_stream: MultiStream) -> MultiStream:
159
+ set_demos_pool = Set(fields={self.demos_pool_field_name: self.demos_pool})
160
+ return set_demos_pool(multi_stream)
161
+
162
+
163
+ class DatasetRecipe(SourceSequentialOperator):
164
+ """This class represents a standard recipe for data processing and preparation.
165
 
166
+ This class can be used to prepare a recipe
167
+ with all necessary steps, refiners and renderers included. It allows setting various
168
+ parameters and steps in a sequential manner for preparing the recipe.
169
+
170
+ Args:
171
+ card (TaskCard):
172
+ TaskCard object associated with the recipe.
173
+ template (Template, optional):
174
+ Template object to be used for the recipe.
175
+ system_prompt (SystemPrompt, optional):
176
+ SystemPrompt object to be used for the recipe.
177
+ loader_limit (int, optional):
178
+ Specifies the maximum number of instances per stream to be returned from the loader (used to reduce loading time in large datasets)
179
+ format (SystemFormat, optional):
180
+ SystemFormat object to be used for the recipe.
181
+ metrics (List[str]):
182
+ list of catalog metrics to use with this recipe.
183
+ postprocessors (List[str]):
184
+ list of catalog processors to apply at post processing. (Not recommended to use from here)
185
+ group_by (List[Union[str, List[str]]]):
186
+ list of task_data or metadata keys to group global scores by.
187
+ train_refiner (StreamRefiner, optional):
188
+ Train refiner to be used in the recipe.
189
+ max_train_instances (int, optional):
190
+ Maximum training instances for the refiner.
191
+ validation_refiner (StreamRefiner, optional):
192
+ Validation refiner to be used in the recipe.
193
+ max_validation_instances (int, optional):
194
+ Maximum validation instances for the refiner.
195
+ test_refiner (StreamRefiner, optional):
196
+ Test refiner to be used in the recipe.
197
+ max_test_instances (int, optional):
198
+ Maximum test instances for the refiner.
199
+ demos_pool_size (int, optional):
200
+ Size of the demos pool. Use -1 to take the whole of the 'demos_taken_from' stream.
201
+ demos_pool (List[Dict[str, Any]], optional):
202
+ a list of instances to serve as the demos_pool
203
+ num_demos (int, optional):
204
+ Number of demos to add to each instance, to become part of the source to be generated for this instance.
205
+ demos_taken_from (str, optional):
206
+ Specifies the stream from where the demos are taken. Default is "train".
207
+ demos_field (str, optional):
208
+ Field name for demos. Default is "demos".
209
+ The num_demos demos selected for an instance are stored in this field of that instance.
210
+ demos_pool_field_name (str, optional):
211
+ field name in which the demos_pool is maintained until it is sampled from to make the demos.
212
+ Defaults to constants.demos_pool_field.
213
+ demos_removed_from_data (bool, optional):
214
+ whether to remove the demos taken into the demos_pool from the source data. Default is True.
215
+ sampler (Sampler, optional):
216
+ The Sampler used to select the demonstrations when num_demos > 0.
217
+ skip_demoed_instances (bool, optional):
218
+ whether to skip pushing demos to an instance whose demos_field is
219
+ already populated. Defaults to False.
220
+ steps (List[StreamingOperator], optional):
221
+ List of StreamingOperator objects to be used in the recipe.
222
+ augmentor (Augmentor) :
223
+ Augmentor to be used to pseudo randomly augment the source text
224
+ instruction_card_index (int, optional):
225
+ Index of instruction card to be used for preparing the recipe.
226
+ template_card_index (int, optional):
227
+ Index of template card to be used for preparing the recipe.
228
+
229
+ Methods:
230
+ prepare():
231
+ This overridden method is used for preparing the recipe
232
+ by arranging all the steps, refiners, and renderers in a sequential manner.
233
+
234
+ Raises:
235
+ AssertionError:
236
+ If both template and template_card_index are specified at the same time.
237
+ """
238
 
 
239
  # Base parameters
240
  card: TaskCard = None
241
  task: Task = None
 
262
  test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
263
 
264
  demos_pool_size: int = None
265
+ demos_pool: List[Dict[str, Any]] = None
266
  num_demos: Optional[Union[int, List[int]]] = 0
267
  demos_removed_from_data: bool = True
268
+ demos_pool_field_name: str = constants.demos_pool_field
269
 
 
270
  demos_taken_from: str = "train"
271
  demos_field: str = "demos"
272
  sampler: Sampler = None
273
 
274
+ # do not push demos to instances whose "demos" field is already populated
275
+ skip_demoed_instances: bool = False
276
+
277
  augmentor: Union[Augmentor, List[Augmentor]] = OptionalField(default=None)
278
 
279
  steps: List[StreamingOperator] = InternalField(default_factory=list)
 
308
  raise ValueError(
309
  "When using demonstrations both num_demos and demos_pool_size should be assigned with positive integers."
310
  )
311
+ if self.demos_pool_size < self.max_demos_size + 1:
312
  raise ValueError(
313
+ f"num_demos (got: {self.max_demos_size}) should not exceed demos_pool_size - 1 (got: {self.demos_pool_size}), (-1: to always allow filtering of a demo identical to the processed instance)."
314
  )
315
+ if (
316
+ (not self.demos_pool)
317
+ and (self.demos_pool_size != sys.maxsize)
318
+ and self.loader_limit
319
+ and (self.demos_pool_size > self.loader_limit)
320
+ ):
321
  raise ValueError(
322
  f"demos_pool_size should not exceed loader_limit ({self.loader_limit}), Got demos_pool_size={self.demos_pool_size}"
323
  )
 
432
  self.loading,
433
  self.metadata,
434
  self.standardization,
 
435
  ]
436
 
437
  self.inference = SequentialOperator()
438
 
439
+ self.inference.steps = [self.processing, self.verbalization, self.finalize]
440
 
441
  def production_preprocess(self, task_instances):
442
  ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
443
+ return list(self.metadata(ms)[constants.inference_stream])
 
 
 
 
 
 
 
 
 
444
 
445
  @property
446
  def has_custom_demos_pool(self):
447
+ return self.demos_pool_size is not None and (
448
+ self.demos_pool_size > 0 or self.demos_pool_size == -1
449
+ )
450
 
451
  @property
452
  def use_demos(self):
 
455
  def produce(self, task_instances):
456
  """Use the recipe in production to produce model ready query from standard task instance."""
457
  self.before_process_multi_stream()
458
+
459
+ ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
460
+ # does not hurt to set metadata
461
+ # task_instances are assumed to be as if passed through self.standardization
462
+ ms = self.metadata(ms)
463
+ if not self.use_demos:
464
+ # go with task_instances all the way, it does not need other streams:
465
+ ms = self.inference(ms)
466
+ return list(ms[constants.inference_stream])
467
+
468
+ streams = self.inference_demos()
469
+ # streams stopped before processing
470
+ # ms is ready to join, it will get the demos from streams
471
+ streams[constants.inference_stream] = ms[constants.inference_stream]
472
+ # multi_stream = MultiStream(streams)
473
+ multi_stream = self.inference(streams)
474
  return list(multi_stream[constants.inference_stream])
475
 
476
  def reset(self):
 
534
  augmentor.set_fields(self.card.task.augmentable_inputs)
535
  self.processing.steps.append(augmentor)
536
 
537
+ # for backward compatibility, consume the demo instances even if they are not pushed into the demos field of the ordinary instances,
538
+ # in order to use the very same ordinary instances as in earlier releases.
539
+ # one example where demos are consumed but not used, which thereby skips over a problematic (json-wise) input:
540
+ # prepare/cards/rag/end_to_end/clapnq.py
541
  if self.has_custom_demos_pool:
542
+ if self.demos_pool:
543
+ self.processing.steps.append(
544
+ AddDemosPool(
545
+ demos_pool=self.demos_pool,
546
+ demos_pool_field_name=self.demos_pool_field_name,
547
+ )
548
+ )
549
+ else:
550
+ self.processing.steps.append(
551
+ CreateDemosPool(
552
+ from_stream=self.demos_taken_from,
553
+ demos_pool_size=self.demos_pool_size
554
+ if self.demos_pool is None
555
+ else None,
556
+ demos_removed_from_data=self.demos_removed_from_data,
557
+ to_field=self.demos_pool_field_name,
558
+ )
559
  )
 
560
 
561
  if self.use_demos:
562
  if self.sampler is None:
 
573
  if isinstance(self.num_demos, int):
574
  self.verbalization.steps.append(
575
  ConstantSizeSample(
576
+ from_field=self.demos_pool_field_name,
577
  to_field=self.demos_field,
578
  sampler=self.sampler,
579
  sample_size=self.num_demos,
580
+ skip_demoed_instances=self.skip_demoed_instances,
581
  )
582
  )
583
  self.verbalization.steps.append(
584
+ Set(
585
+ fields={
586
+ "recipe_metadata/num_demos": self.num_demos,
587
+ "recipe_metadata/demos_pool_size": self.demos_pool_size,
588
+ }
589
+ )
590
  )
591
 
592
  elif isinstance(self.num_demos, list):
593
  self.verbalization.steps.append(
594
  RandomSizeSample(
595
+ from_field=self.demos_pool_field_name,
596
  to_field=self.demos_field,
597
  sampler=self.sampler,
598
  sample_sizes=self.num_demos,
599
+ skip_demoed_instances=self.skip_demoed_instances,
600
  )
601
  )
602
  self.verbalization.steps.append(
603
  GetLength(field="demos", to_field="recipe_metadata/num_demos")
604
  )
605
+ self.verbalization.steps.append(
606
+ Set(
607
+ fields={"recipe_metadata/demos_pool_size": self.demos_pool_size}
608
+ )
609
+ )
610
+
611
  else:
612
  raise ValueError("num_demos must be int or List[int]")
613
 
 
623
  template=self.template, demos_field=self.demos_field
624
  )
625
  )
626
+
627
  else:
628
  self.verbalization.steps.append(
629
+ Set(
630
+ fields={
631
+ "recipe_metadata/num_demos": 0,
632
+ "recipe_metadata/demos_pool_size": 0,
633
+ }
634
+ )
635
  )
636
  if isinstance(self.template, list):
637
  self.verbalization.steps.append(
 
655
 
656
  self.finalize.steps.append(FinalizeDataset(group_by=self.group_by))
657
 
 
 
 
 
 
 
 
 
 
658
  def prepare(self):
659
  assert (
660
  self.template_card_index is None or self.template is None
 
701
  raise ValueError(
702
  "No template was specified in the the 'template' or 'template_card_index' recipe arguments, and no default templates are defined the card or task"
703
  )
704
+ if self.use_demos:
705
+ assert (
706
+ self.demos_pool is not None
707
+ and isoftype(self.demos_pool, List[Dict[str, Any]])
708
+ ) != (
709
+ self.demos_taken_from is not None
710
+ and self.demos_pool_size is not None
711
+ and self.demos_removed_from_data is not None
712
+ ), (
713
+ "The demos_pool must be specified by exactly one of two ways: explicitly, as a list of instances coming through parameter "
714
+ + "'demos_pool', or via parameters 'demos_taken_from', 'demos_pool_size', and 'demos_removed_from_data', "
715
+ + "that together direct its production."
716
+ )
717
 
718
+ # now set self.demos_pool_size for the checks done by verify
719
+ if self.demos_pool:
720
+ self.demos_pool_size = len(self.demos_pool)
721
+ if self.demos_pool_size is not None and self.demos_pool_size == -1:
722
+ self.demos_pool_size = sys.maxsize
723
 
724
+ if isinstance(self.template, TemplatesList):
725
+ self.template = self.template.items
726
+ self.reset_pipeline()
727
 
 
 
728
 
729
+ @deprecation(version="2.0.0", alternative=DatasetRecipe)
730
+ class BaseRecipe(DatasetRecipe):
731
+ pass
732
 
 
 
 
 
 
 
 
 
 
 
733
 
734
+ @deprecation(version="2.0.0", alternative=DatasetRecipe)
735
+ class StandardRecipeWithIndexes(DatasetRecipe):
736
+ pass
 
737
 
 
 
 
 
738
 
739
+ @deprecation(version="2.0.0", alternative=DatasetRecipe)
740
+ class StandardRecipe(DatasetRecipe):
741
  pass
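Taken together, the reworked recipe lets the demos pool be given either explicitly (demos_pool) or implicitly (demos_taken_from / demos_pool_size / demos_removed_from_data), but not both. A hedged usage sketch follows; the import path and catalog identifiers are assumptions for illustration, not taken from this diff.

from unitxt.standard import DatasetRecipe  # assumed import path

# pool built implicitly from the 'train' stream
recipe = DatasetRecipe(
    card="cards.some_card",              # hypothetical catalog id
    template="templates.some_template",  # hypothetical catalog id
    num_demos=3,                         # may also be a list of sizes, drawn randomly per instance
    demos_pool_size=20,                  # -1 would consume the whole 'train' stream
)

# pool given explicitly; instances that already carry a 'demos' field keep it
recipe_explicit = DatasetRecipe(
    card="cards.some_card",
    template="templates.some_template",
    num_demos=1,
    demos_pool=[                         # illustrative instance shape
        {"input_fields": {"text": "first example"}, "reference_fields": {"label": "A"}},
        {"input_fields": {"text": "second example"}, "reference_fields": {"label": "B"}},
    ],
    skip_demoed_instances=True,
)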
struct_data_operators.py CHANGED
@@ -679,7 +679,7 @@ class LoadJson(FieldOperator):
679
  except json.JSONDecodeError:
680
  return self.failure_value
681
  else:
682
- return json.loads(value)
683
 
684
 
685
  class DumpJson(FieldOperator):
 
679
  except json.JSONDecodeError:
680
  return self.failure_value
681
  else:
682
+ return json.loads(value, strict=False)
683
 
684
 
685
  class DumpJson(FieldOperator):
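The LoadJson change above switches json.loads to strict=False, which makes Python's json module accept literal control characters (such as a raw newline or tab) inside string values instead of raising. A quick standalone illustration:

import json

raw = '{"text": "line one\nline two"}'   # the \n is a real newline inside the JSON string value

print(json.loads(raw, strict=False))     # {'text': 'line one\nline two'}
# json.loads(raw)                        # default strict=True raises json.JSONDecodeError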
task.py CHANGED
@@ -40,25 +40,22 @@ def parse_string_types_instead_of_actual_objects(obj):
40
  class Task(InstanceOperator, ArtifactFetcherMixin):
41
  """Task packs the different instance fields into dictionaries by their roles in the task.
42
 
43
- Attributes:
44
  input_fields (Union[Dict[str, str], List[str]]):
45
- Dictionary with string names of instance input fields and types of respective values.
46
- In case a list is passed, each type will be assumed to be Any.
47
-
48
  reference_fields (Union[Dict[str, str], List[str]]):
49
- Dictionary with string names of instance output fields and types of respective values.
50
- In case a list is passed, each type will be assumed to be Any.
51
-
52
- metrics (List[str]): List of names of metrics to be used in the task.
53
-
54
  prediction_type (Optional[str]):
55
- Need to be consistent with all used metrics. Defaults to None, which means that it will
56
- be set to Any.
57
-
58
  defaults (Optional[Dict[str, Any]]):
59
- An optional dictionary with default values for chosen input/output keys. Needs to be
60
- consistent with names and types provided in 'input_fields' and/or 'output_fields' arguments.
61
- Will not overwrite values if already provided in a given instance.
62
 
63
  The output instance contains three fields:
64
  1. "input_fields" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'.
@@ -119,7 +116,7 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
119
  self.prediction_type
120
  )
121
 
122
- def verify(self):
123
  if hasattr(self, "inputs") and self.inputs is not None:
124
  depr_message = (
125
  "The 'inputs' field is deprecated. Please use 'input_fields' instead."
@@ -130,6 +127,9 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
130
  depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
131
  warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
132
 
 
 
 
133
  if self.input_fields is None:
134
  raise UnitxtError(
135
  "Missing attribute in task: 'input_fields' not set.",
@@ -155,7 +155,11 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
155
  f"will raise an exception.",
156
  Documentation.ADDING_TASK,
157
  )
158
- data = {key: Any for key in data}
 
 
 
 
159
  if io_type == "input_fields":
160
  self.input_fields = data
161
  else:
@@ -290,6 +294,9 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
290
  "media": instance.get("media", {}),
291
  "recipe_metadata": instance.get("recipe_metadata", {}),
292
  }
 
 
 
293
 
294
  if stream_name == constants.inference_stream:
295
  return result
 
40
  class Task(InstanceOperator, ArtifactFetcherMixin):
41
  """Task packs the different instance fields into dictionaries by their roles in the task.
42
 
43
+ Args:
44
  input_fields (Union[Dict[str, str], List[str]]):
45
+ Dictionary with string names of instance input fields and types of respective values.
46
+ In case a list is passed, each type will be assumed to be Any.
 
47
  reference_fields (Union[Dict[str, str], List[str]]):
48
+ Dictionary with string names of instance output fields and types of respective values.
49
+ In case a list is passed, each type will be assumed to be Any.
50
+ metrics (List[str]):
51
+ List of names of metrics to be used in the task.
 
52
  prediction_type (Optional[str]):
53
+ Need to be consistent with all used metrics. Defaults to None, which means that it will
54
+ be set to Any.
 
55
  defaults (Optional[Dict[str, Any]]):
56
+ An optional dictionary with default values for chosen input/output keys. Needs to be
57
+ consistent with names and types provided in 'input_fields' and/or 'output_fields' arguments.
58
+ Will not overwrite values if already provided in a given instance.
59
 
60
  The output instance contains three fields:
61
  1. "input_fields" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'.
 
116
  self.prediction_type
117
  )
118
 
119
+ def task_deprecations(self):
120
  if hasattr(self, "inputs") and self.inputs is not None:
121
  depr_message = (
122
  "The 'inputs' field is deprecated. Please use 'input_fields' instead."
 
127
  depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
128
  warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
129
 
130
+ def verify(self):
131
+ self.task_deprecations()
132
+
133
  if self.input_fields is None:
134
  raise UnitxtError(
135
  "Missing attribute in task: 'input_fields' not set.",
 
155
  f"will raise an exception.",
156
  Documentation.ADDING_TASK,
157
  )
158
+ if isinstance(data, dict):
159
+ data = parse_type_dict(to_type_dict(data))
160
+ else:
161
+ data = {key: Any for key in data}
162
+
163
  if io_type == "input_fields":
164
  self.input_fields = data
165
  else:
 
294
  "media": instance.get("media", {}),
295
  "recipe_metadata": instance.get("recipe_metadata", {}),
296
  }
297
+ if "demos" in instance:
298
+ # for the case of recipe.skip_demoed_instances
299
+ result["demos"] = instance["demos"]
300
 
301
  if stream_name == constants.inference_stream:
302
  return result
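With the dict handling restored above (parse_type_dict(to_type_dict(data))), input_fields and reference_fields may be given either as name-to-type dictionaries or as plain lists of names (types then default to Any). A hedged sketch of a task definition; the field names, metric id, and import path are illustrative assumptions:

from unitxt.task import Task  # assumed import path

sentiment_task = Task(
    input_fields={"text": str},          # name -> type
    reference_fields={"label": str},
    prediction_type=str,                 # must be consistent with the metrics used
    metrics=["metrics.accuracy"],        # assumed catalog id
    defaults={"label": "neutral"},       # filled in only when an instance lacks the key
)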
templates.py CHANGED
@@ -307,26 +307,27 @@ class PairwiseChoiceTemplate(InputOutputTemplate):
307
  The answer field value should be of type Literal["choice_a", "choice_b", "tie"]
308
 
309
  Args:
310
- choice_a_field (str): The field which contains choice_a value
311
-
312
- choice_b_field (str): The field which contains choice_b value
313
-
314
- answer_field (str): The field which contains the answer value.
315
- Should be of type Literal["choice_1", "choice_2", "tie"]
316
-
317
- choice_a_label (str): The label of choice A answer as it is verbalized in the template.
318
-
319
- choice_b_label (str): The label of choice B answer as it is verbalized in the template.
320
-
321
- choice_tie_label (str): The label of a tie answer as it should be verbalized in the template.
322
-
323
- shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.
 
324
 
325
  shuffle: 50% of the time:
326
  1. The values of choice_a_field and choice_b_field will be swapped.
327
  2. If the value of answer_field is choice_a_label, set it to choice_b_label.
328
- | Else if the values of answer_field is choice_b_label, set it to choice_a_label.
329
- | Else if the value of answer_field is choice_tie_label, do nothing.
330
 
331
  """
332
 
@@ -636,21 +637,22 @@ class MultipleChoiceTemplate(InputFormatTemplate):
636
  class YesNoTemplate(InputFormatTemplate):
637
  """A template for generating binary Yes/No questions asking whether an input text is of a specific class.
638
 
639
- input_format:
640
- Defines the format of the question.
641
- class_field:
642
- Defines the field that contains the name of the class that this template
643
- asks of.
644
- label_field:
645
- Defines the field which contains the true label of the input text. If a gold label is equal to the
646
- value in class_name, then the correct output is self.yes_answer (by default, "Yes").
647
- Otherwise the correct output is self.no_answer (by default, "No").
648
- yes_answer:
649
- The output value for when the gold label equals self.class_name.
650
- Defaults to "Yes".
651
- no_answer:
652
- The output value for when the gold label differs from self.class_name.
653
- Defaults to "No".
 
654
  """
655
 
656
  input_format: str = None
 
307
  The answer field value should be of type Literal["choice_a", "choice_b", "tie"]
308
 
309
  Args:
310
+ choice_a_field (str):
311
+ The field which contains choice_a value
312
+ choice_b_field (str):
313
+ The field which contains choice_b value
314
+ answer_field (str):
315
+ The field which contains the answer value.
316
+ Should be of type Literal["choice_1", "choice_2", "tie"]
317
+ choice_a_label (str):
318
+ The label of choice A answer as it is verbalized in the template.
319
+ choice_b_label (str):
320
+ The label of choice B answer as it is verbalized in the template.
321
+ choice_tie_label (str):
322
+ The label of a tie answer as it should be verbalized in the template.
323
+ shuffle (bool):
324
+ whether to shuffle the choices or not. This is done to take into account position bias.
325
 
326
  shuffle: 50% of the time:
327
  1. The values of choice_a_field and choice_b_field will be swapped.
328
  2. If the value of answer_field is choice_a_label, set it to choice_b_label.
329
+ Else if the value of answer_field is choice_b_label, set it to choice_a_label.
330
+ Else if the value of answer_field is choice_tie_label, do nothing.
331
 
332
  """
333
 
 
637
  class YesNoTemplate(InputFormatTemplate):
638
  """A template for generating binary Yes/No questions asking whether an input text is of a specific class.
639
 
640
+ Args:
641
+ input_format:
642
+ Defines the format of the question.
643
+ class_field:
644
+ Defines the field that contains the name of the class that this template
645
+ asks of.
646
+ label_field:
647
+ Defines the field which contains the true label of the input text. If a gold label is equal to the
648
+ value in class_name, then the correct output is self.yes_answer (by default, "Yes").
649
+ Otherwise the correct output is self.no_answer (by default, "No").
650
+ yes_answer:
651
+ The output value for when the gold label equals self.class_name.
652
+ Defaults to "Yes".
653
+ no_answer:
654
+ The output value for when the gold label differs from self.class_name.
655
+ Defaults to "No".
656
  """
657
 
658
  input_format: str = None
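To make the relocated docstring concrete, here is a hedged instantiation sketch for YesNoTemplate; the format string, placeholder names, field values, and import path are illustrative assumptions:

from unitxt.templates import YesNoTemplate  # assumed import path

template = YesNoTemplate(
    input_format="Is the following text about {class}?\nText: {text}",  # assumed placeholders
    class_field="class",
    label_field="labels",
    yes_answer="Yes",   # returned when the gold label equals the class name
    no_answer="No",     # returned otherwise
)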
type_utils.py CHANGED
@@ -552,6 +552,9 @@ def strtype(typing_type) -> str:
552
  - The function checks the `__origin__` attribute to determine the base type and formats
553
  the type arguments accordingly.
554
  """
 
 
 
555
  if not is_type(typing_type):
556
  raise UnsupportedTypeError(typing_type)
557
 
 
552
  - The function checks the `__origin__` attribute to determine the base type and formats
553
  the type arguments accordingly.
554
  """
555
+ if isinstance(typing_type, str):
556
+ return typing_type
557
+
558
  if not is_type(typing_type):
559
  raise UnsupportedTypeError(typing_type)
560
 
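The new early return above makes strtype pass strings through unchanged, so a type that is already in its string form is accepted instead of hitting UnsupportedTypeError. A small usage sketch (the import path and the exact rendered string are assumptions):

from typing import List
from unitxt.type_utils import strtype  # assumed import path

rendered = strtype(List[str])   # e.g. "List[str]"
print(strtype(rendered))        # the string now comes back unchanged instead of raising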
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.16.0"
 
1
+ version = "1.16.1"