import inspect
import json
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict

from .artifact import fetch_artifact
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import UNITXT_DATASET_SCHEMA, loads_instance
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def load(source: Union[SourceOperator, str]):
    assert isinstance(
        source, (SourceOperator, str)
    ), "source must be a SourceOperator or a string"
    if isinstance(source, str):
        source, _ = fetch_artifact(source)
    return source().to_dataset()
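# A minimal usage sketch (the recipe id below is hypothetical; any catalog id that
# resolves to a SourceOperator, or a SourceOperator instance, can be passed):
#
#     ds = load("recipes.my_recipe")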


def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except Exception:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, DatasetRecipe):
        return dataset_query

    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query)

    if kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    return recipe
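# Sketch: the same recipe can be described by a query string or by keyword
# arguments (the card and template ids below must exist in the local catalog):
#
#     recipe = load_recipe("card=cards.wnli,template=templates.classification.multi_class.relation.default")
#     recipe = load_recipe(card="cards.wnli", template="templates.classification.multi_class.relation.default")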


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of test instances
        train_set: optional list of train instances
        validation_set: optional list of validation instances
        split: optional name of a single split to return
        **kwargs: Additional arguments used to load the dataset from the provided data (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(loader=LoadFromDictionary(data=data), task=task)
    return load_dataset(card=card, split=split, **kwargs)
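# Sketch showing an optional train split alongside the required test set (the task id
# follows the docstring example above; the template id and instance fields are
# illustrative and must match the chosen task's schema):
#
#     dataset = create_dataset(
#         task="tasks.qa.open",
#         test_set=[{"question": "What is the capital of France?", "answers": ["Paris"]}],
#         train_set=[{"question": "What is 1+1?", "answers": ["2"]}],
#         template="templates.qa.open.simple",
#     )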


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    disable_cache: Optional[bool] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, the dataset is loaded from a card
    in the local catalog based on the parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            local catalog or name of specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        split (str, optional):
            The single split of the data to load (e.g. "train" or "test").
        streaming (bool, optional):
            When True, returns an iterable (streaming) dataset instead of a fully
            materialized one. Defaults to False.
        disable_cache (bool, optional):
            Whether to disable caching of the loaded data. Defaults to the global
            Unitxt settings value.
        **kwargs:
            Arguments used to load the dataset from a provided card that is not
            present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    stream = recipe()
    if split is not None:
        stream = stream[split]

    if disable_cache is None:
        disable_cache = settings.disable_hf_datasets_cache

    if streaming:
        dataset = stream.to_iterable_dataset(
            features=UNITXT_DATASET_SCHEMA,
        ).map(loads_instance, batched=True)
    else:
        dataset = stream.to_dataset(
            features=UNITXT_DATASET_SCHEMA, disable_cache=disable_cache
        ).with_transform(loads_instance)

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset
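# Sketch: selecting a single split or streaming the data (the card and template ids
# are the same illustrative ones used in the docstring example above):
#
#     test_split = load_dataset(
#         "card=cards.stsb,template=templates.regression.two_texts.simple",
#         split="test",
#     )
#     streamed = load_dataset(card=card, template=template, streaming=True, disable_cache=True)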


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions, dataset: Optional[Union[Dataset, IterableDataset]] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(predictions=predictions, references=dataset)
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result
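# Sketch of a typical flow, given a dataset loaded above and model predictions that
# are already aligned with it (printing the results object is only illustrative):
#
#     results = evaluate(predictions=predictions, dataset=test_split)
#     print(results)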


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_instance)
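# Sketch: turning raw input fields into ready-to-infer instances, using the same
# query syntax as load_dataset (the card/template ids and field names are hypothetical):
#
#     single = produce({"field": "value"}, "card=cards.some_card,template=templates.some_template")
#     batch = produce([{"field": "value"}], card="cards.some_card", template="templates.some_template")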


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions
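# Sketch: end-to-end inference over raw instances (the engine may be an
# InferenceEngine instance or a catalog id; the other ids below are hypothetical):
#
#     predictions = infer(
#         [{"field": "value"}],
#         engine=my_engine,
#         dataset_query="card=cards.some_card,template=templates.some_template",
#         return_data=True,  # also return the produced dataset with prediction columns added
#     )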


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions
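# Sketch: option selection by log-probabilities (the engine must implement
# OptionSelectingByLogProbsInferenceEngine; the ids below are hypothetical):
#
#     choices = select(
#         [{"field": "value"}],
#         engine=my_logprob_engine,
#         dataset_query="card=cards.some_card,template=templates.some_template",
#     )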