Elron committed on
Commit 365fb61 · verified · 1 Parent(s): 91ef70a

Upload folder using huggingface_hub

Files changed (7)
  1. api.py +52 -27
  2. formats.py +50 -0
  3. loaders.py +14 -74
  4. metrics.py +0 -1
  5. operators.py +205 -1
  6. settings_utils.py +4 -2
  7. version.py +1 -1
api.py CHANGED
@@ -1,10 +1,13 @@
+ import hashlib
  import inspect
  import json
+ import tempfile
  from datetime import datetime
  from functools import lru_cache
  from typing import Any, Dict, List, Optional, Union
  
  from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
+ from datasets.exceptions import DatasetGenerationError
  
  from .artifact import fetch_artifact
  from .card import TaskCard
@@ -19,7 +22,7 @@ from .loaders import LoadFromDictionary
  from .logging_utils import get_logger
  from .metric_utils import EvaluationResults, _compute, _inference_post_process
  from .operator import SourceOperator
- from .schema import UNITXT_DATASET_SCHEMA, loads_instance
+ from .schema import loads_instance
  from .settings_utils import get_constants, get_settings
  from .standard import DatasetRecipe
  from .task import Task
@@ -29,13 +32,9 @@ constants = get_constants()
  settings = get_settings()
  
  
- def load(source: Union[SourceOperator, str]):
-     assert isinstance(
-         source, (SourceOperator, str)
-     ), "source must be a SourceOperator or a string"
-     if isinstance(source, str):
-         source, _ = fetch_artifact(source)
-     return source().to_dataset()
+ def short_hex_hash(value, length=8):
+     h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
+     return h[:length]
  
  
  def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
@@ -135,11 +134,44 @@ def create_dataset(
      return load_dataset(card=card, split=split, **kwargs)
  
  
+ def _source_to_dataset(
+     source: SourceOperator, split=None, use_cache=False, streaming=False
+ ):
+     from .dataset import Dataset as UnitxtDataset
+ 
+     stream = source()
+ 
+     with tempfile.TemporaryDirectory() as dir_to_be_deleted:
+         cache_dir = dir_to_be_deleted if not use_cache else None
+         ds_builder = UnitxtDataset(
+             dataset_name="unitxt",
+             config_name="recipe-" + short_hex_hash(source.to_json()),
+             hash=hash(source.to_json()),
+             version=constants.version,
+             cache_dir=cache_dir,
+         )
+         if split is not None:
+             stream = {split: stream[split]}
+         ds_builder._generators = stream
+ 
+         try:
+             ds_builder.download_and_prepare()
+ 
+             if streaming:
+                 return ds_builder.as_streaming_dataset(split=split)
+ 
+             return ds_builder.as_dataset(
+                 split=split, run_post_process=False, verification_mode="no_checks"
+             )
+         except DatasetGenerationError as e:
+             raise e.__cause__
+ 
+ 
  def load_dataset(
      dataset_query: Optional[str] = None,
      split: Optional[str] = None,
      streaming: bool = False,
-     disable_cache: Optional[bool] = None,
+     use_cache: Optional[bool] = False,
      **kwargs,
  ) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
      """Loads dataset.
@@ -156,11 +188,16 @@ def load_dataset(
              local catalog or name of specific recipe or benchmark in the catalog. For
              example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
          streaming (bool, False):
-             When True yields the data as Unitxt streams dictionary
+             When True yields the data as a stream.
+             This is useful when loading very large datasets.
+             Loading datasets as streams avoids loading all the data into memory, but requires the dataset's loader to support streaming.
          split (str, optional):
              The split of the data to load
-         disable_cache (str, optional):
-             Disable caching process of the data
+         use_cache (bool, optional):
+             If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
+             If set to False (default), the returned dataset is not cached.
+             Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
+             Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
          **kwargs:
              Arguments used to load dataset from provided card, which is not present in local catalog.
  
@@ -184,21 +221,9 @@
      """
      recipe = load_recipe(dataset_query, **kwargs)
  
-     stream = recipe()
-     if split is not None:
-         stream = stream[split]
- 
-     if disable_cache is None:
-         disable_cache = settings.disable_hf_datasets_cache
- 
-     if streaming:
-         dataset = stream.to_iterable_dataset(
-             features=UNITXT_DATASET_SCHEMA,
-         ).map(loads_instance, batched=True)
-     else:
-         dataset = stream.to_dataset(
-             features=UNITXT_DATASET_SCHEMA, disable_cache=disable_cache
-         ).with_transform(loads_instance)
+     dataset = _source_to_dataset(
+         source=recipe, split=split, use_cache=use_cache, streaming=streaming
+     )
  
      frame = inspect.currentframe()
      args, _, _, values = inspect.getargvalues(frame)
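A minimal usage sketch for the reworked load_dataset above (assuming the catalog entries named in the docstring, cards.wnli and its multi_class relation template, are available locally); the use_cache and streaming flags are simply forwarded to _source_to_dataset:

    from unitxt.api import load_dataset

    query = "card=cards.wnli,template=templates.classification.multi_class.relation.default"

    # Default behaviour: the dataset is built in a temporary directory and
    # nothing is persisted in the Hugging Face cache (use_cache=False).
    train = load_dataset(query, split="train")

    # With use_cache=True the prepared dataset is kept on local disk, so a
    # second call with the same recipe is served from the cache.
    train_cached = load_dataset(query, split="train", use_cache=True)

    # With streaming=True an iterable dataset is returned via as_streaming_dataset().
    train_stream = load_dataset(query, split="train", streaming=True)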
formats.py CHANGED
@@ -13,6 +13,7 @@ from typing import (
  
  from .dataclass import OptionalField
  from .dict_utils import dict_get
+ from .error_utils import UnitxtError
  from .image_operators import image_to_data_url
  from .operator import InstanceOperator
  from .settings_utils import get_constants
@@ -25,6 +26,55 @@ class Format(InstanceOperator):
      pass
  
  
+ class GraniteDocumentsFormat(Format):
+     model: str = "ibm-granite/granite-3.1-8b-instruct"
+     citations: bool = True
+     length: str = "long"
+ 
+     _requirements_list = ["transformers"]
+ 
+     def prepare(self):
+         super().prepare()
+         from transformers import AutoTokenizer
+ 
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model)
+ 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
+         inputs = instance["input_fields"]
+         if "question" not in inputs:
+             raise UnitxtError(
+                 "GraniteDocumentsFormat works only for tasks with field: 'question'"
+             )
+         if "context" not in inputs and "contexts" not in inputs:
+             raise UnitxtError(
+                 "GraniteDocumentsFormat works only for tasks with field: 'context' or 'contexts'"
+             )
+ 
+         if "context" in inputs:
+             texts = [inputs["context"]]
+         if "contexts" in inputs:
+             texts = inputs["contexts"]
+ 
+         documents = []
+         for text in texts:
+             documents.append({"title": "", "text": text})
+ 
+         question = inputs["question"]
+ 
+         instance["source"] = self.tokenizer.apply_chat_template(
+             [
+                 {"role": "user", "content": question},
+             ],
+             documents=documents,
+             controls={"citations": self.citations, "length": self.length},
+             add_generation_prompt=True,
+             tokenize=False,
+         )
+         return instance
+ 
+ 
  def apply_capital_new_line_notation(text: str) -> str:
      r"""Transforms a given string by applying the Capital New Line Notation.
  
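A sketch of plugging GraniteDocumentsFormat into a RAG-style recipe; the card and template names below are illustrative placeholders, and the format needs the transformers package plus access to the ibm-granite/granite-3.1-8b-instruct tokenizer:

    from unitxt.api import load_dataset
    from unitxt.formats import GraniteDocumentsFormat

    # The format renders each instance with the Granite chat template,
    # passing the task's context(s) as `documents` and enabling citations.
    dataset = load_dataset(
        card="cards.rag.response_generation.clap_nq",  # placeholder card name
        template="templates.rag.response_generation.simple",  # placeholder template name
        format=GraniteDocumentsFormat(citations=True, length="long"),
        split="test",
    )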
loaders.py CHANGED
@@ -53,7 +53,7 @@ from typing import (
  
  import pandas as pd
  import requests
- from datasets import IterableDatasetDict
+ from datasets import DatasetDict, IterableDatasetDict
  from datasets import load_dataset as hf_load_dataset
  from huggingface_hub import HfApi
  from tqdm import tqdm
@@ -210,7 +210,7 @@ class LoadHF(Loader):
          Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
      ] = None
      revision: Optional[str] = None
-     streaming: bool = True
+     streaming: bool = None
      filtering_lambda: Optional[str] = None
      num_proc: Optional[int] = None
      requirements_list: List[str] = OptionalField(default_factory=list)
@@ -221,7 +221,7 @@
                  self._requirements_list.append(requirement)
          super().verify()
  
-     def filter_load(self, dataset):
+     def filter_load(self, dataset: DatasetDict):
          if not settings.allow_unverified_code:
              raise ValueError(
                  f"{self.__class__.__name__} cannot run use filtering_lambda expression without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE=True."
@@ -229,9 +229,14 @@
          logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
          return dataset.filter(eval(self.filtering_lambda))
  
+     def is_streaming(self) -> bool:
+         if self.streaming is None:
+             return settings.stream_hf_datasets_by_default
+         return self.streaming
+ 
      def stream_dataset(self):
          with tempfile.TemporaryDirectory() as dir_to_be_deleted:
-             if settings.disable_hf_datasets_cache and not self.streaming:
+             if settings.disable_hf_datasets_cache and not self.is_streaming():
                  cache_dir = dir_to_be_deleted
              else:
                  cache_dir = None
@@ -242,7 +247,7 @@
                      data_dir=self.data_dir,
                      data_files=self.data_files,
                      revision=self.revision,
-                     streaming=self.streaming,
+                     streaming=self.is_streaming(),
                      cache_dir=cache_dir,
                      split=self.split,
                      trust_remote_code=settings.allow_unverified_code,
@@ -288,11 +293,8 @@
                  f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
              ) from e
  
-         if self.split is None:
-             for split in dataset.keys():
-                 dataset[split] = dataset[split].to_iterable_dataset()
-         else:
-             dataset = {self.split: dataset.to_iterable_dataset()}
+         if self.split is not None:
+             dataset = {self.split: dataset}
  
          return dataset
  
@@ -824,6 +826,8 @@ class LoadFromHFSpace(LoadHF):
      token_env: Optional[str] = None
      requirements_list: List[str] = ["huggingface_hub"]
  
+     streaming: bool = True
+ 
      def _get_token(self) -> Optional[Union[bool, str]]:
          if self.token_env:
              token = os.getenv(self.token_env)
@@ -954,70 +958,6 @@
          self.path = self._download_data()
          return super().load_data()
  
- # url: str
- 
- # _requirements_list: List[str] = ["opendatasets"]
- # data_classification_policy = ["public"]
- 
- # def verify(self):
- #     super().verify()
- #     if not os.path.isfile("kaggle.json"):
- #         raise MissingKaggleCredentialsError(
- #             "Please obtain kaggle credentials https://christianjmills.com/posts/kaggle-obtain-api-key-tutorial/ and save them to local ./kaggle.json file"
- #         )
- 
- #     if self.streaming:
- #         raise NotImplementedError("LoadFromKaggle cannot load with streaming.")
- 
- # def prepare(self):
- #     super().prepare()
- #     from opendatasets import download
- 
- #     self.downloader = download
- 
- # def load_iterables(self):
- #     with TemporaryDirectory() as temp_directory:
- #         self.downloader(self.url, temp_directory)
- #         return hf_load_dataset(temp_directory, streaming=False)
- 
- # class LoadFromAPI(Loader):
- #     """Loads data from from API"""
- 
- #     urls: Dict[str, str]
- #     chunksize: int = 100000
- #     loader_limit: Optional[int] = None
- #     streaming: bool = False
- 
- #     def _maybe_set_classification_policy(self):
- #         self.set_default_data_classification(["proprietary"], "when loading from API")
- 
- #     def load_iterables(self):
-         self.api_key = os.getenv("SQL_API_KEY", None)
-         if not self.api_key:
-             raise ValueError(
-                 "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
-             )
- 
-         self.base_headers = {
-             "Content-Type": "application/json",
-             "accept": "application/json",
-             "Authorization": f"Bearer {self.api_key}",
-         }
- 
-         iterables = {}
-         for split_name, url in self.urls.items():
-             response = requests.get(
-                 url,
-                 headers=self.base_headers,
-                 verify=True,
-             )
- 
-             iterables[split_name] = pd.DataFrame(
-                 json.loads(response.text)["embeddings"]
-             )
- 
-         return iterables
- 
  
  class LoadFromAPI(Loader):
      """Loads data from from API.
metrics.py CHANGED
@@ -1886,7 +1886,6 @@ class RelaxedCorrectness(GlobalMetric):
              "relaxed_augmented_split": [],
          }
          for pred, ref, task_data_i in zip(predictions, references, task_data):
-             print(task_data_i)
              type = task_data_i["type"]
              score = self.relaxed_correctness(pred, ref[0])
              score = 1.0 if score else 0.0
operators.py CHANGED
@@ -67,6 +67,7 @@ from .artifact import Artifact, fetch_artifact
  from .dataclass import NonPositionalField, OptionalField
  from .deprecation_utils import deprecation
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
+ from .error_utils import UnitxtError
  from .generator_utils import ReusableGenerator
  from .operator import (
      InstanceOperator,
@@ -84,7 +85,7 @@ from .operator import (
  from .random_utils import new_random_generator
  from .settings_utils import get_settings
  from .stream import DynamicStream, Stream
- from .text_utils import nested_tuple_to_string
+ from .text_utils import nested_tuple_to_string, to_pretty_string
  from .type_utils import isoftype
  from .utils import (
      LRUCache,
@@ -1476,6 +1477,113 @@ class Intersect(FieldOperator):
          return [e for e in value if e in self.allowed_values]
  
  
+ class IntersectCorrespondingFields(InstanceOperator):
+     """Intersects the value of a field, which must be a list, with a given list, and removes the corresponding elements from other list fields.
+ 
+     For example:
+ 
+     Assume the instances contain a field of 'labels' and a field with the labels' corresponding 'positions' in the text.
+ 
+     IntersectCorrespondingFields(field="label",
+                                  allowed_values=["b", "f"],
+                                  corresponding_fields_to_intersect=["position"])
+ 
+     would keep only the "b" and "f" values in the 'labels' field and
+     their respective values in the 'position' field.
+     (All other fields are not affected)
+ 
+     Given this input:
+ 
+     [
+         {"label": ["a", "b"], "position": [0, 1], "other": "not"},
+         {"label": ["a", "c", "d"], "position": [0, 1, 2], "other": "relevant"},
+         {"label": ["a", "b", "f"], "position": [0, 1, 2], "other": "field"}
+     ]
+ 
+     the output would be:
+ 
+     [
+         {"label": ["b"], "position": [1], "other": "not"},
+         {"label": [], "position": [], "other": "relevant"},
+         {"label": ["b", "f"], "position": [1, 2], "other": "field"},
+     ]
+ 
+     Args:
+         field - the field to intersect (must contain list values)
+         allowed_values (list) - the list of values to keep
+         corresponding_fields_to_intersect (list) - additional list fields from which values
+             are removed based on the corresponding indices of the values removed from 'field'
+     """
+ 
+     field: str
+     allowed_values: List[str]
+     corresponding_fields_to_intersect: List[str]
+ 
+     def verify(self):
+         super().verify()
+ 
+         if not isinstance(self.allowed_values, list):
+             raise ValueError(
+                 f"The allowed_values is not a list but '{type(self.allowed_values)}'"
+             )
+ 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
+         if self.field not in instance:
+             raise ValueError(
+                 f"Field '{self.field}' is not in provided instance.\n"
+                 + to_pretty_string(instance)
+             )
+ 
+         for corresponding_field in self.corresponding_fields_to_intersect:
+             if corresponding_field not in instance:
+                 raise ValueError(
+                     f"Field '{corresponding_field}' is not in provided instance.\n"
+                     + to_pretty_string(instance)
+                 )
+ 
+         if not isinstance(instance[self.field], list):
+             raise ValueError(
+                 f"Value of field '{self.field}' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\n"
+                 + to_pretty_string(instance, keys=[self.field])
+             )
+ 
+         num_values_in_field = len(instance[self.field])
+ 
+         if set(self.allowed_values) == set(instance[self.field]):
+             return instance
+ 
+         indices_to_keep = [
+             i
+             for i, value in enumerate(instance[self.field])
+             if value in set(self.allowed_values)
+         ]
+ 
+         result_instance = {}
+         for field_name, field_value in instance.items():
+             if (
+                 field_name in self.corresponding_fields_to_intersect
+                 or field_name == self.field
+             ):
+                 if not isinstance(field_value, list):
+                     raise ValueError(
+                         f"Value of field '{field_name}' is not a list, IntersectCorrespondingFields can not intersect with allowed values."
+                     )
+                 if len(field_value) != num_values_in_field:
+                     raise ValueError(
+                         f"Number of elements in field '{field_name}' is not the same as the number of elements in field '{self.field}' so the IntersectCorrespondingFields can not remove corresponding values.\n"
+                         + to_pretty_string(instance, keys=[self.field, field_name])
+                     )
+                 result_instance[field_name] = [
+                     value
+                     for index, value in enumerate(field_value)
+                     if index in indices_to_keep
+                 ]
+             else:
+                 result_instance[field_name] = field_value
+         return result_instance
+ 
+ 
  class RemoveValues(FieldOperator):
      """Removes elements in a field, which must be a list, using a given list of unallowed.
  
@@ -2243,6 +2351,102 @@ class CollateInstances(StreamOperator):
          )
  
  
+ class CollateInstancesByField(StreamOperator):
+     """Groups a list of instances by a specified field, aggregates specified fields into lists, and ensures consistency for all other non-aggregated fields.
+ 
+     Args:
+         by_field (str): the name of the field to group data by.
+         aggregate_fields (list(str)): the field names to aggregate into lists.
+ 
+     Returns:
+         A stream of instances grouped and aggregated by the specified field.
+ 
+     Raises:
+         UnitxtError: If non-aggregate fields have inconsistent values.
+ 
+     Example:
+         Collate the instances based on field "category" and aggregate fields "value" and "id".
+ 
+         CollateInstancesByField(by_field="category", aggregate_fields=["value", "id"])
+ 
+         given input:
+         [
+             {"id": 1, "category": "A", "value": 10, "flag": True},
+             {"id": 2, "category": "B", "value": 20, "flag": False},
+             {"id": 3, "category": "A", "value": 30, "flag": True},
+             {"id": 4, "category": "B", "value": 40, "flag": False}
+         ]
+ 
+         the output is:
+         [
+             {"category": "A", "id": [1, 3], "value": [10, 30], "flag": True},
+             {"category": "B", "id": [2, 4], "value": [20, 40], "flag": False}
+         ]
+ 
+         Note that the "flag" field is not aggregated, and must be the same
+         in all instances in the same category, or an error is raised.
+     """
+ 
+     by_field: str = NonPositionalField(required=True)
+     aggregate_fields: List[str] = NonPositionalField(required=True)
+ 
+     def prepare(self):
+         super().prepare()
+ 
+     def verify(self):
+         super().verify()
+         if not isinstance(self.by_field, str):
+             raise UnitxtError(
+                 f"The 'by_field' value is not a string but '{type(self.by_field)}'"
+             )
+ 
+         if not isinstance(self.aggregate_fields, list):
+             raise UnitxtError(
+                 f"The 'aggregate_fields' value is not a list but '{type(self.aggregate_fields)}'"
+             )
+ 
+     def process(self, stream: Stream, stream_name: Optional[str] = None):
+         grouped_data = {}
+ 
+         for instance in stream:
+             if self.by_field not in instance:
+                 raise UnitxtError(
+                     f"The field '{self.by_field}' specified by CollateInstancesByField's 'by_field' argument is not found in instance."
+                 )
+             for k in self.aggregate_fields:
+                 if k not in instance:
+                     raise UnitxtError(
+                         f"The field '{k}' specified in CollateInstancesByField's 'aggregate_fields' argument is not found in instance."
+                     )
+             key = instance[self.by_field]
+ 
+             if key not in grouped_data:
+                 grouped_data[key] = {
+                     k: v for k, v in instance.items() if k not in self.aggregate_fields
+                 }
+                 # Add empty lists for fields to aggregate
+                 for agg_field in self.aggregate_fields:
+                     if agg_field in instance:
+                         grouped_data[key][agg_field] = []
+ 
+             for k, v in instance.items():
+                 # Merge classification policy list across instances with the same key
+                 if k == "data_classification_policy" and instance[k]:
+                     grouped_data[key][k] = sorted(set(grouped_data[key][k] + v))
+                 # Check consistency for all non-aggregate fields
+                 elif k != self.by_field and k not in self.aggregate_fields:
+                     if k in grouped_data[key] and grouped_data[key][k] != v:
+                         raise ValueError(
+                             f"Inconsistent value for field '{k}' in group '{key}': "
+                             f"'{grouped_data[key][k]}' vs '{v}'. Ensure that all non-aggregated fields in CollateInstancesByField are consistent across all instances."
+                         )
+                 # Aggregate fields
+                 elif k in self.aggregate_fields:
+                     grouped_data[key][k].append(instance[k])
+ 
+         yield from grouped_data.values()
+ 
+ 
  class WikipediaFetcher(FieldOperator):
      mode: Literal["summary", "text"] = "text"
      _requirements_list = ["Wikipedia-API"]
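Minimal sketches of the two new operators, applied here directly through their process methods on plain Python dicts (standalone use outside a full Unitxt recipe is an assumption made for illustration):

    from unitxt.operators import CollateInstancesByField, IntersectCorrespondingFields

    # Keep only labels "b" and "f" and drop the matching entries in "position".
    intersect = IntersectCorrespondingFields(
        field="label",
        allowed_values=["b", "f"],
        corresponding_fields_to_intersect=["position"],
    )
    print(intersect.process({"label": ["a", "b", "f"], "position": [0, 1, 2], "other": "field"}))
    # -> {"label": ["b", "f"], "position": [1, 2], "other": "field"}

    # Group instances by "category", collecting "id" and "value" into lists.
    collate = CollateInstancesByField(by_field="category", aggregate_fields=["id", "value"])
    instances = [
        {"id": 1, "category": "A", "value": 10, "flag": True},
        {"id": 3, "category": "A", "value": 30, "flag": True},
    ]
    print(list(collate.process(instances)))
    # -> [{"category": "A", "flag": True, "id": [1, 3], "value": [10, 30]}]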
settings_utils.py CHANGED
@@ -149,8 +149,10 @@ if Settings.is_uninitilized():
      settings.skip_artifacts_prepare_and_verify = (bool, False)
      settings.data_classification_policy = None
      settings.mock_inference_mode = (bool, False)
-     settings.disable_hf_datasets_cache = (bool, True)
-     settings.loader_cache_size = (int, 1)
+     settings.disable_hf_datasets_cache = (bool, False)
+     settings.stream_hf_datasets_by_default = (bool, False)
+ 
+     settings.loader_cache_size = (int, 10)
      settings.task_data_as_text = (bool, True)
      settings.default_provider = "watsonx"
      settings.default_format = None
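A sketch of how the new defaults interact with loading, read and overridden through the shared settings object (assuming no UNITXT_* environment variables override them); LoadHF now resolves its streaming flag from stream_hf_datasets_by_default when it is left unset:

    from unitxt.loaders import LoadHF
    from unitxt.settings_utils import get_settings

    settings = get_settings()

    # Defaults introduced by this commit (unless overridden by the environment).
    assert settings.disable_hf_datasets_cache is False
    assert settings.stream_hf_datasets_by_default is False
    assert settings.loader_cache_size == 10

    # A LoadHF loader with streaming left as None follows the global default,
    # while an explicit value always wins.
    assert LoadHF(path="glue", name="wnli").is_streaming() is False
    assert LoadHF(path="glue", name="wnli", streaming=True).is_streaming() is True

    # Re-enable the previous cache-less behaviour for the current process.
    settings.disable_hf_datasets_cache = True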
version.py CHANGED
@@ -1 +1 @@
- version = "1.17.1"
+ version = "1.17.2"