Upload folder using huggingface_hub
- api.py +52 -27
- formats.py +50 -0
- loaders.py +14 -74
- metrics.py +0 -1
- operators.py +205 -1
- settings_utils.py +4 -2
- version.py +1 -1
api.py
CHANGED
@@ -1,10 +1,13 @@
+import hashlib
 import inspect
 import json
+import tempfile
 from datetime import datetime
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Union

 from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
+from datasets.exceptions import DatasetGenerationError

 from .artifact import fetch_artifact
 from .card import TaskCard
@@ -19,7 +22,7 @@ from .loaders import LoadFromDictionary
 from .logging_utils import get_logger
 from .metric_utils import EvaluationResults, _compute, _inference_post_process
 from .operator import SourceOperator
-from .schema import
+from .schema import loads_instance
 from .settings_utils import get_constants, get_settings
 from .standard import DatasetRecipe
 from .task import Task
@@ -29,13 +32,9 @@ constants = get_constants()
 settings = get_settings()


-def
-
-
-    ), "source must be a SourceOperator or a string"
-    if isinstance(source, str):
-        source, _ = fetch_artifact(source)
-    return source().to_dataset()
+def short_hex_hash(value, length=8):
+    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
+    return h[:length]


 def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
@@ -135,11 +134,44 @@ def create_dataset(
     return load_dataset(card=card, split=split, **kwargs)


+def _source_to_dataset(
+    source: SourceOperator, split=None, use_cache=False, streaming=False
+):
+    from .dataset import Dataset as UnitxtDataset
+
+    stream = source()
+
+    with tempfile.TemporaryDirectory() as dir_to_be_deleted:
+        cache_dir = dir_to_be_deleted if not use_cache else None
+        ds_builder = UnitxtDataset(
+            dataset_name="unitxt",
+            config_name="recipe-" + short_hex_hash(source.to_json()),
+            hash=hash(source.to_json()),
+            version=constants.version,
+            cache_dir=cache_dir,
+        )
+        if split is not None:
+            stream = {split: stream[split]}
+        ds_builder._generators = stream
+
+        try:
+            ds_builder.download_and_prepare()
+
+            if streaming:
+                return ds_builder.as_streaming_dataset(split=split)
+
+            return ds_builder.as_dataset(
+                split=split, run_post_process=False, verification_mode="no_checks"
+            )
+        except DatasetGenerationError as e:
+            raise e.__cause__
+
+
 def load_dataset(
     dataset_query: Optional[str] = None,
     split: Optional[str] = None,
     streaming: bool = False,
-
+    use_cache: Optional[bool] = False,
     **kwargs,
 ) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
     """Loads dataset.
@@ -156,11 +188,16 @@ def load_dataset(
         local catalog or name of specific recipe or benchmark in the catalog. For
         example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
     streaming (bool, False):
-        When True yields the data as
+        When True yields the data as a stream.
+        This is useful when loading very large datasets.
+        Loading datasets as streams avoid loading all the data to memory, but requires the dataset's loader to support streaming.
     split (str, optional):
         The split of the data to load
-
-
+    use_cache (bool, optional):
+        If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
+        If set to False (default), the returned dataset is not cached.
+        Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
+        Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
     **kwargs:
         Arguments used to load dataset from provided card, which is not present in local catalog.
@@ -184,21 +221,9 @@ def load_dataset(
     """
     recipe = load_recipe(dataset_query, **kwargs)

-
-
-
-
-    if disable_cache is None:
-        disable_cache = settings.disable_hf_datasets_cache
-
-    if streaming:
-        dataset = stream.to_iterable_dataset(
-            features=UNITXT_DATASET_SCHEMA,
-        ).map(loads_instance, batched=True)
-    else:
-        dataset = stream.to_dataset(
-            features=UNITXT_DATASET_SCHEMA, disable_cache=disable_cache
-        ).with_transform(loads_instance)
+    dataset = _source_to_dataset(
+        source=recipe, split=split, use_cache=use_cache, streaming=streaming
+    )

     frame = inspect.currentframe()
     args, _, _, values = inspect.getargvalues(frame)
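A minimal usage sketch of the updated load_dataset API introduced above. The card/template query string is the placeholder example quoted in the docstring, not part of this commit; treat the snippet as illustrative only.

# Sketch assuming unitxt is installed and the catalog entries below exist.
from unitxt import load_dataset

# With use_cache=True the prepared HuggingFace dataset is written to the
# regular datasets cache and reused on the next identical call.
ds = load_dataset(
    "card=cards.wnli,template=templates.classification.multi_class.relation.default",
    split="train",
    use_cache=True,
)

# With streaming=True an iterable dataset is returned instead, avoiding
# materializing all instances in memory.
stream_ds = load_dataset(
    "card=cards.wnli,template=templates.classification.multi_class.relation.default",
    split="train",
    streaming=True,
)
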
formats.py
CHANGED
@@ -13,6 +13,7 @@ from typing import (

 from .dataclass import OptionalField
 from .dict_utils import dict_get
+from .error_utils import UnitxtError
 from .image_operators import image_to_data_url
 from .operator import InstanceOperator
 from .settings_utils import get_constants
@@ -25,6 +26,55 @@ class Format(InstanceOperator):
     pass


+class GraniteDocumentsFormat(Format):
+    model: str = "ibm-granite/granite-3.1-8b-instruct"
+    citations: bool = True
+    length: str = "long"
+
+    _requirements_list = ["transformers"]
+
+    def prepare(self):
+        super().prepare()
+        from transformers import AutoTokenizer
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model)
+
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        inputs = instance["input_fields"]
+        if "question" not in inputs:
+            raise UnitxtError(
+                "GraniteRAGFormat works only for tasks with field: 'question'"
+            )
+        if "context" not in inputs and "contexts" not in inputs:
+            raise UnitxtError(
+                "GraniteRAGFormat works only for tasks with field: 'context' or 'contexts"
+            )
+
+        if "context" in inputs:
+            texts = [inputs["context"]]
+        if "contexts" in inputs:
+            texts = inputs["contexts"]
+
+        documents = []
+        for text in texts:
+            documents.append({"title": "", "text": text})
+
+        question = inputs["question"]
+
+        instance["source"] = self.tokenizer.apply_chat_template(
+            [
+                {"role": "user", "content": question},
+            ],
+            documents=documents,
+            controls={"citations": self.citations, "length": self.length},
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        return instance
+
+
 def apply_capital_new_line_notation(text: str) -> str:
     r"""Transforms a given string by applying the Capital New Line Notation.

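A hedged sketch of how the new GraniteDocumentsFormat might be used, assuming a RAG-style card whose task exposes 'question' and 'contexts' input fields. The card and template names below are placeholders chosen for illustration; format is the standard recipe argument unitxt uses for attaching a Format artifact.

from unitxt import load_dataset
from unitxt.formats import GraniteDocumentsFormat

# The format renders the question plus the retrieved documents through the
# Granite chat template, with citation and response-length controls.
ds = load_dataset(
    card="cards.rag.response_generation.clapnq",          # placeholder card
    template="templates.rag.response_generation.simple",  # placeholder template
    format=GraniteDocumentsFormat(citations=True, length="long"),
    split="test",
)
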
loaders.py
CHANGED
@@ -53,7 +53,7 @@ from typing import (

 import pandas as pd
 import requests
-from datasets import IterableDatasetDict
+from datasets import DatasetDict, IterableDatasetDict
 from datasets import load_dataset as hf_load_dataset
 from huggingface_hub import HfApi
 from tqdm import tqdm
@@ -210,7 +210,7 @@ class LoadHF(Loader):
         Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
     ] = None
     revision: Optional[str] = None
-    streaming: bool =
+    streaming: bool = None
     filtering_lambda: Optional[str] = None
     num_proc: Optional[int] = None
     requirements_list: List[str] = OptionalField(default_factory=list)
@@ -221,7 +221,7 @@ class LoadHF(Loader):
                 self._requirements_list.append(requirement)
         super().verify()

-    def filter_load(self, dataset):
+    def filter_load(self, dataset: DatasetDict):
         if not settings.allow_unverified_code:
             raise ValueError(
                 f"{self.__class__.__name__} cannot run use filtering_lambda expression without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE=True."
@@ -229,9 +229,14 @@ class LoadHF(Loader):
         logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
         return dataset.filter(eval(self.filtering_lambda))

+    def is_streaming(self) -> bool:
+        if self.streaming is None:
+            return settings.stream_hf_datasets_by_default
+        return self.streaming
+
     def stream_dataset(self):
         with tempfile.TemporaryDirectory() as dir_to_be_deleted:
-            if settings.disable_hf_datasets_cache and not self.
+            if settings.disable_hf_datasets_cache and not self.is_streaming():
                 cache_dir = dir_to_be_deleted
             else:
                 cache_dir = None
@@ -242,7 +247,7 @@ class LoadHF(Loader):
                 data_dir=self.data_dir,
                 data_files=self.data_files,
                 revision=self.revision,
-                streaming=self.
+                streaming=self.is_streaming(),
                 cache_dir=cache_dir,
                 split=self.split,
                 trust_remote_code=settings.allow_unverified_code,
@@ -288,11 +293,8 @@ class LoadHF(Loader):
                     f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
                 ) from e

-        if self.split is None:
-
-                dataset[split] = dataset[split].to_iterable_dataset()
-        else:
-            dataset = {self.split: dataset.to_iterable_dataset()}
+        if self.split is not None:
+            dataset = {self.split: dataset}

         return dataset

@@ -824,6 +826,8 @@ class LoadFromHFSpace(LoadHF):
     token_env: Optional[str] = None
     requirements_list: List[str] = ["huggingface_hub"]

+    streaming: bool = True
+
     def _get_token(self) -> Optional[Union[bool, str]]:
         if self.token_env:
             token = os.getenv(self.token_env)
@@ -954,70 +958,6 @@ class LoadFromHFSpace(LoadHF):
         self.path = self._download_data()
         return super().load_data()

-    # url: str
-
-    # _requirements_list: List[str] = ["opendatasets"]
-    # data_classification_policy = ["public"]
-
-    # def verify(self):
-    #     super().verify()
-    #     if not os.path.isfile("kaggle.json"):
-    #         raise MissingKaggleCredentialsError(
-    #             "Please obtain kaggle credentials https://christianjmills.com/posts/kaggle-obtain-api-key-tutorial/ and save them to local ./kaggle.json file"
-    #         )
-
-    #     if self.streaming:
-    #         raise NotImplementedError("LoadFromKaggle cannot load with streaming.")
-
-    # def prepare(self):
-    #     super().prepare()
-    #     from opendatasets import download
-
-    #     self.downloader = download
-
-    # def load_iterables(self):
-    #     with TemporaryDirectory() as temp_directory:
-    #         self.downloader(self.url, temp_directory)
-    #         return hf_load_dataset(temp_directory, streaming=False)
-
-# class LoadFromAPI(Loader):
-#     """Loads data from from API"""
-
-#     urls: Dict[str, str]
-#     chunksize: int = 100000
-#     loader_limit: Optional[int] = None
-#     streaming: bool = False
-
-#     def _maybe_set_classification_policy(self):
-#         self.set_default_data_classification(["proprietary"], "when loading from API")
-
-#     def load_iterables(self):
-        self.api_key = os.getenv("SQL_API_KEY", None)
-        if not self.api_key:
-            raise ValueError(
-                "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
-            )
-
-        self.base_headers = {
-            "Content-Type": "application/json",
-            "accept": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-        }
-
-        iterables = {}
-        for split_name, url in self.urls.items():
-            response = requests.get(
-                url,
-                headers=self.base_headers,
-                verify=True,
-            )
-
-            iterables[split_name] = pd.DataFrame(
-                json.loads(response.text)["embeddings"]
-            )
-
-        return iterables
-

 class LoadFromAPI(Loader):
     """Loads data from from API.
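A brief sketch of how the new streaming default is resolved for LoadHF after this change: streaming now defaults to None and falls back to settings.stream_hf_datasets_by_default, while an explicit value still wins. The dataset path below ("glue"/"wnli") is only an illustration.

from unitxt.loaders import LoadHF
from unitxt.settings_utils import get_settings

settings = get_settings()
settings.stream_hf_datasets_by_default = False  # global default used when streaming is None

default_loader = LoadHF(path="glue", name="wnli")             # follows the global default
assert default_loader.is_streaming() is False

explicit_loader = LoadHF(path="glue", name="wnli", streaming=True)  # explicit value wins
assert explicit_loader.is_streaming() is True
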
metrics.py
CHANGED
@@ -1886,7 +1886,6 @@ class RelaxedCorrectness(GlobalMetric):
             "relaxed_augmented_split": [],
         }
         for pred, ref, task_data_i in zip(predictions, references, task_data):
-            print(task_data_i)
             type = task_data_i["type"]
             score = self.relaxed_correctness(pred, ref[0])
             score = 1.0 if score else 0.0
operators.py
CHANGED
@@ -67,6 +67,7 @@ from .artifact import Artifact, fetch_artifact
 from .dataclass import NonPositionalField, OptionalField
 from .deprecation_utils import deprecation
 from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
+from .error_utils import UnitxtError
 from .generator_utils import ReusableGenerator
 from .operator import (
     InstanceOperator,
@@ -84,7 +85,7 @@ from .operator import (
 from .random_utils import new_random_generator
 from .settings_utils import get_settings
 from .stream import DynamicStream, Stream
-from .text_utils import nested_tuple_to_string
+from .text_utils import nested_tuple_to_string, to_pretty_string
 from .type_utils import isoftype
 from .utils import (
     LRUCache,
@@ -1476,6 +1477,113 @@ class Intersect(FieldOperator):
         return [e for e in value if e in self.allowed_values]


+class IntersectCorrespondingFields(InstanceOperator):
+    """Intersects the value of a field, which must be a list, with a given list , and removes corresponding elements from other list fields.
+
+    For example:
+
+    Assume the instances contain a field of 'labels' and a field with the labels' corresponding 'positions' in the text.
+
+    IntersectCorrespondingFields(field="label",
+                                 allowed_values=["b", "f"],
+                                 corresponding_fields_to_intersect=["position"])
+
+    would keep only "b" and "f" values in 'labels' field and
+    their respective values in the 'position' field.
+    (All other fields are not effected)
+
+    Given this input:
+
+    [
+        {"label": ["a", "b"],"position": [0,1],"other" : "not"},
+        {"label": ["a", "c", "d"], "position": [0,1,2], "other" : "relevant"},
+        {"label": ["a", "b", "f"], "position": [0,1,2], "other" : "field"}
+    ]
+
+    So the output would be:
+    [
+        {"label": ["b"], "position":[1],"other" : "not"},
+        {"label": [], "position": [], "other" : "relevant"},
+        {"label": ["b", "f"],"position": [1,2], "other" : "field"},
+    ]
+
+    Args:
+        field - the field to intersected (must contain list values)
+        allowed_values (list) - list of values to keep
+        corresponding_fields_to_intersect (list) - additional list fields from which values
+            are removed based the corresponding indices of values removed from the 'field'
+    """
+
+    field: str
+    allowed_values: List[str]
+    corresponding_fields_to_intersect: List[str]
+
+    def verify(self):
+        super().verify()
+
+        if not isinstance(self.allowed_values, list):
+            raise ValueError(
+                f"The allowed_field_values is not a type list but '{type(self.allowed_field_values)}'"
+            )
+
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        if self.field not in instance:
+            raise ValueError(
+                f"Field '{self.field}' is not in provided instance.\n"
+                + to_pretty_string(instance)
+            )
+
+        for corresponding_field in self.corresponding_fields_to_intersect:
+            if corresponding_field not in instance:
+                raise ValueError(
+                    f"Field '{corresponding_field}' is not in provided instance.\n"
+                    + to_pretty_string(instance)
+                )
+
+        if not isinstance(instance[self.field], list):
+            raise ValueError(
+                f"Value of field '{self.field}' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\n"
+                + to_pretty_string(instance, keys=[self.field])
+            )
+
+        num_values_in_field = len(instance[self.field])
+
+        if set(self.allowed_values) == set(instance[self.field]):
+            return instance
+
+        indices_to_keep = [
+            i
+            for i, value in enumerate(instance[self.field])
+            if value in set(self.allowed_values)
+        ]
+
+        result_instance = {}
+        for field_name, field_value in instance.items():
+            if (
+                field_name in self.corresponding_fields_to_intersect
+                or field_name == self.field
+            ):
+                if not isinstance(field_value, list):
+                    raise ValueError(
+                        f"Value of field '{field_name}' is not a list, IntersectCorrespondingFields can not intersect with allowed values."
+                    )
+                if len(field_value) != num_values_in_field:
+                    raise ValueError(
+                        f"Number of elements in field '{field_name}' is not the same as the number of elements in field '{self.field}' so the IntersectCorrespondingFields can not remove corresponding values.\n"
+                        + to_pretty_string(instance, keys=[self.field, field_name])
+                    )
+                result_instance[field_name] = [
+                    value
+                    for index, value in enumerate(field_value)
+                    if index in indices_to_keep
+                ]
+            else:
+                result_instance[field_name] = field_value
+        return result_instance
+
+
 class RemoveValues(FieldOperator):
     """Removes elements in a field, which must be a list, using a given list of unallowed.

@@ -2243,6 +2351,102 @@ class CollateInstances(StreamOperator):
         )


+class CollateInstancesByField(StreamOperator):
+    """Groups a list of instances by a specified field, aggregates specified fields into lists, and ensures consistency for all other non-aggregated fields.
+
+    Args:
+        by_field str: the name of the field to group data by.
+        aggregate_fields list(str): the field names to aggregate into lists.
+
+    Returns:
+        A stream of instances grouped and aggregated by the specified field.
+
+    Raises:
+        UnitxtError: If non-aggregate fields have inconsistent values.
+
+    Example:
+        Collate the instances based on field "category" and aggregate fields "value" and "id".
+
+        CollateInstancesByField(by_field="category", aggregate_fields=["value", "id"])
+
+        given input:
+        [
+            {"id": 1, "category": "A", "value": 10", "flag" : True},
+            {"id": 2, "category": "B", "value": 20", "flag" : False},
+            {"id": 3, "category": "A", "value": 30", "flag" : True},
+            {"id": 4, "category": "B", "value": 40", "flag" : False}
+        ]
+
+        the output is:
+        [
+            {"category": "A", "id": [1, 3], "value": [10, 30], "info": True},
+            {"category": "B", "id": [2, 4], "value": [20, 40], "info": False}
+        ]
+
+        Note that the "flag" field is not aggregated, and must be the same
+        in all instances in the same category, or an error is raised.
+    """
+
+    by_field: str = NonPositionalField(required=True)
+    aggregate_fields: List[str] = NonPositionalField(required=True)
+
+    def prepare(self):
+        super().prepare()
+
+    def verify(self):
+        super().verify()
+        if not isinstance(self.by_field, str):
+            raise UnitxtError(
+                f"The 'by_field' value is not a string but '{type(self.by_field)}'"
+            )
+
+        if not isinstance(self.aggregate_fields, list):
+            raise UnitxtError(
+                f"The 'allowed_field_values' is not a list but '{type(self.aggregate_fields)}'"
+            )
+
+    def process(self, stream: Stream, stream_name: Optional[str] = None):
+        grouped_data = {}
+
+        for instance in stream:
+            if self.by_field not in instance:
+                raise UnitxtError(
+                    f"The field '{self.by_field}' specified by CollateInstancesByField's 'by_field' argument is not found in instance."
+                )
+            for k in self.aggregate_fields:
+                if k not in instance:
+                    raise UnitxtError(
+                        f"The field '{k}' specified in CollateInstancesByField's 'aggregate_fields' argument is not found in instance."
+                    )
+            key = instance[self.by_field]
+
+            if key not in grouped_data:
+                grouped_data[key] = {
+                    k: v for k, v in instance.items() if k not in self.aggregate_fields
+                }
+                # Add empty lists for fields to aggregate
+                for agg_field in self.aggregate_fields:
+                    if agg_field in instance:
+                        grouped_data[key][agg_field] = []
+
+            for k, v in instance.items():
+                # Merge classification policy list across instance with same key
+                if k == "data_classification_policy" and instance[k]:
+                    grouped_data[key][k] = sorted(set(grouped_data[key][k] + v))
+                # Check consistency for all non-aggregate fields
+                elif k != self.by_field and k not in self.aggregate_fields:
+                    if k in grouped_data[key] and grouped_data[key][k] != v:
+                        raise ValueError(
+                            f"Inconsistent value for field '{k}' in group '{key}': "
+                            f"'{grouped_data[key][k]}' vs '{v}'. Ensure that all non-aggregated fields in CollateInstancesByField are consistent across all instances."
+                        )
+                # Aggregate fields
+                elif k in self.aggregate_fields:
+                    grouped_data[key][k].append(instance[k])
+
+        yield from grouped_data.values()
+
+
 class WikipediaFetcher(FieldOperator):
     mode: Literal["summary", "text"] = "text"
     _requirements_list = ["Wikipedia-API"]
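A hedged sketch of the two new operators applied directly to toy data via their process methods (in a real pipeline they would normally sit in a card's preprocess_steps); the instance values below are illustrative only.

from unitxt.operators import CollateInstancesByField, IntersectCorrespondingFields

# Keep only labels "b" and "f", dropping the matching entries in "position".
intersect = IntersectCorrespondingFields(
    field="label",
    allowed_values=["b", "f"],
    corresponding_fields_to_intersect=["position"],
)
print(intersect.process({"label": ["a", "b"], "position": [0, 1], "other": "not"}))
# -> {"label": ["b"], "position": [1], "other": "not"}

# Group instances by "category", collecting "id" and "value" into lists;
# the non-aggregated "flag" field must be identical within each group.
collate = CollateInstancesByField(by_field="category", aggregate_fields=["value", "id"])
instances = [
    {"id": 1, "category": "A", "value": 10, "flag": True},
    {"id": 3, "category": "A", "value": 30, "flag": True},
]
print(list(collate.process(instances)))
# -> [{"category": "A", "flag": True, "value": [10, 30], "id": [1, 3]}]
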
settings_utils.py
CHANGED
@@ -149,8 +149,10 @@ if Settings.is_uninitilized():
     settings.skip_artifacts_prepare_and_verify = (bool, False)
     settings.data_classification_policy = None
     settings.mock_inference_mode = (bool, False)
-    settings.disable_hf_datasets_cache = (bool,
-    settings.
+    settings.disable_hf_datasets_cache = (bool, False)
+    settings.stream_hf_datasets_by_default = (bool, False)
+
+    settings.loader_cache_size = (int, 10)
     settings.task_data_as_text = (bool, True)
     settings.default_provider = "watsonx"
     settings.default_format = None
version.py
CHANGED
@@ -1 +1 @@
-version = "1.17.
+version = "1.17.2"