|
import json |
|
from functools import lru_cache |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict |
|
|
|
from .artifact import fetch_artifact |
|
from .card import TaskCard |
|
from .dataset_utils import get_dataset_artifact |
|
from .error_utils import UnitxtError |
|
from .inference import ( |
|
InferenceEngine, |
|
LogProbInferenceEngine, |
|
OptionSelectingByLogProbsInferenceEngine, |
|
) |
|
from .loaders import LoadFromDictionary |
|
from .logging_utils import get_logger |
|
from .metric_utils import EvaluationResults, _compute, _inference_post_process |
|
from .operator import SourceOperator |
|
from .schema import UNITXT_DATASET_SCHEMA, loads_instance |
|
from .settings_utils import get_constants, get_settings |
|
from .standard import DatasetRecipe |
|
from .task import Task |
|
|
|
# Module-level handles created once at import time.
logger = get_logger()
constants = get_constants()
# `settings.disable_hf_datasets_cache` is the fallback for `disable_cache` in load_dataset().
settings = get_settings()
|
|
|
|
|
def load(source: Union[SourceOperator, str]):
    """Run a source operator and return its output as a dataset.

    Args:
        source: A ``SourceOperator`` instance, or a string that is resolved
            into one through ``fetch_artifact``.

    Returns:
        The result of calling ``to_dataset()`` on the object returned by
        invoking the operator.

    Raises:
        AssertionError: If ``source`` is neither a ``SourceOperator`` nor a string.
    """
    accepted_types = (SourceOperator, str)
    assert isinstance(
        source, accepted_types
    ), "source must be a SourceOperator or a string"

    operator = source
    if isinstance(operator, str):
        # A string is a reference that has to be resolved into the operator itself.
        operator, _ = fetch_artifact(operator)

    produced = operator()
    return produced.to_dataset()
|
|
|
|
|
def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    """Resolve a dataset query string into a recipe.

    Every occurrence of ``"sys_prompt"`` in the query is rewritten to
    ``"instruction"`` before the lookup. The query is first resolved with
    ``fetch_artifact``; if that raises, ``get_dataset_artifact`` is used as
    a fallback.

    Args:
        dataset_query: The query string describing the dataset to load.

    Returns:
        The object returned by whichever of the two lookups succeeded.

    Raises:
        Exception: Whatever ``get_dataset_artifact`` raises when the fallback
            also fails.
    """
    # NOTE(review): "sys_prompt" looks like a legacy alias of "instruction" - confirm.
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except Exception:
        # Catch `Exception` rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit are not swallowed by the fallback.
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream
|
|
|
|
|
def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    """Build a ``DatasetRecipe`` from keyword parameters.

    Each key of ``dataset_params`` is checked against the names found in
    ``DatasetRecipe.__dict__["__fields__"]`` before the recipe is constructed.

    Args:
        dataset_params: Mapping of recipe attribute names to their values.

    Returns:
        ``DatasetRecipe(**dataset_params)``.

    Raises:
        AssertionError: If a key is not one of the recipe's field names.
    """
    known_fields = list(DatasetRecipe.__dict__["__fields__"].keys())
    for name in dataset_params:
        assert name in known_fields, (
            f"The parameter '{name}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{known_fields}'."
        )
    return DatasetRecipe(**dataset_params)
|
|
|
|
|
def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None): |
|
if dataset_query and dataset_args: |
|
raise ValueError( |
|
"Cannot provide 'dataset_query' and key-worded arguments at the same time. " |
|
"If you want to load dataset from a card in local catalog, use query only. " |
|
"Otherwise, use key-worded arguments only to specify properties of dataset." |
|
) |
|
|
|
if dataset_query: |
|
if not isinstance(dataset_query, str): |
|
raise ValueError( |
|
f"If specified, 'dataset_query' must be a string, however, " |
|
f"'{dataset_query}' was provided instead, which is of type " |
|
f"'{type(dataset_query)}'." |
|
) |
|
|
|
if not dataset_query and not dataset_args: |
|
raise ValueError( |
|
"Either 'dataset_query' or key-worded arguments must be provided." |
|
) |
|
|
|
|
|
def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    """Obtain a recipe from a query string or from key-worded arguments.

    Args:
        dataset_query: A query string. If a ``DatasetRecipe`` instance is
            passed instead, it is returned unchanged and ``kwargs`` are ignored.
        **kwargs: Recipe attributes, used when no query is given.

    Returns:
        The resolved ``DatasetRecipe``.

    Raises:
        ValueError: If ``_verify_dataset_args`` rejects the argument combination.
    """
    if isinstance(dataset_query, DatasetRecipe):
        return dataset_query

    _verify_dataset_args(dataset_query, kwargs)

    # Past the verification exactly one of the query / kwargs is truthy.
    if dataset_query:
        return _get_recipe_from_query(dataset_query)
    return _get_recipe_from_dict(kwargs)
|
|
|
|
|
def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Build a dataset for a task out of in-memory instances.

    The instances are wrapped in a ``LoadFromDictionary`` loader, combined
    with the task into a ``TaskCard``, and handed to ``load_dataset()``.

    Args:
        task: The task, or a reference to it that ``fetch_artifact`` can resolve,
            e.g. a name from the Unitxt Catalog
            (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html).
        test_set: Required list of instances for the "test" split.
        train_set: Optional list of instances for the "train" split.
        validation_set: Optional list of instances for the "validation" split.
        split: Optional single split to return.
        **kwargs: Passed on to ``load_dataset()``.

    Returns:
        Whatever ``load_dataset()`` returns for the constructed card.

    Raises:
        Exception: If no ``template`` keyword is passed and the task has no
            ``default_template``.

    Example:
        template = Template(...)
        dataset = create_dataset(
            task="tasks.qa.open", test_set=instances, template=template, format="formats.chatapi"
        )
    """
    optional_splits = {"train": train_set, "validation": validation_set}
    data = {"test": test_set}
    data.update(
        {name: rows for name, rows in optional_splits.items() if rows is not None}
    )

    resolved_task, _ = fetch_artifact(task)

    # A template must come either from the caller or from the task itself.
    template_given = "template" in kwargs
    if not template_given and resolved_task.default_template is None:
        raise Exception(
            f"No 'template' was passed to the create_dataset() and the given task ('{resolved_task.__id__}') has no 'default_template' field."
        )

    loader = LoadFromDictionary(data=data)
    card = TaskCard(loader=loader, task=resolved_task)
    return load_dataset(card=card, split=split, **kwargs)
|
|
|
|
|
def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    disable_cache: Optional[bool] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Load a dataset from a query string or from explicit recipe arguments.

    With ``dataset_query`` the dataset is described by a query string, e.g.
    a card and template from the local catalog. Without it, the recipe is
    built from the key-worded arguments.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from the
            local catalog, or the name of a specific recipe or benchmark in the
            catalog. For example,
            ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        split (str, optional):
            The single split to load. When omitted, all splits are returned.
        streaming (bool, False):
            When True, the result is built with ``to_iterable_dataset()``
            instead of ``to_dataset()``.
        disable_cache (bool, optional):
            Passed to ``to_dataset()`` on the non-streaming path. When None,
            ``settings.disable_hf_datasets_cache`` is used.
        **kwargs:
            Arguments used to build the recipe when no query is given.

    Returns:
        The loaded dataset, with ``loads_instance`` applied to its instances.

    :Example:

    .. code-block:: python

        dataset = load_dataset(
            dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
        )  # card and template must be present in local catalog

        # or built programmatically
        card = TaskCard(...)
        template = Template(...)
        loader_limit = 10
        dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    produced = load_recipe(dataset_query, **kwargs)()
    selected = produced if split is None else produced[split]

    cache_disabled = (
        settings.disable_hf_datasets_cache if disable_cache is None else disable_cache
    )

    if not streaming:
        materialized = selected.to_dataset(
            features=UNITXT_DATASET_SCHEMA, disable_cache=cache_disabled
        )
        return materialized.with_transform(loads_instance)

    iterable = selected.to_iterable_dataset(features=UNITXT_DATASET_SCHEMA)
    return iterable.map(loads_instance, batched=True)
|
|
|
|
|
def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    """Score predictions against a dataset.

    Args:
        predictions: The predictions to evaluate.
        dataset: The references to evaluate against.
        data: Alternative way to pass the references; takes precedence over
            ``dataset`` when it is not None.

    Returns:
        The result of ``_compute`` for the predictions and references.

    Raises:
        UnitxtError: If neither ``dataset`` nor ``data`` is given.
    """
    references = dataset if data is None else data
    if references is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    return _compute(predictions=predictions, references=references)
|
|
|
|
|
def post_process(predictions, data) -> List[Dict[str, Any]]:
    """Post-process predictions, using ``data`` as their references.

    Thin wrapper that forwards both arguments to ``_inference_post_process``.
    """
    processed = _inference_post_process(predictions=predictions, references=data)
    return processed
|
|
|
|
|
@lru_cache(maxsize=128)
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    """Return the ``produce`` attribute of the recipe for the given arguments.

    Results are memoized per argument combination, so all arguments must be hashable.
    """
    recipe = load_recipe(dataset_query, **kwargs)
    return recipe.produce
|
|
|
|
|
def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    """Run one instance, or a list of instances, through a recipe.

    Args:
        instance_or_instances: A single instance or a list of instances.
        dataset_query: Query string selecting the recipe.
        **kwargs: Recipe arguments, used when no query is given.

    Returns:
        For a list input, a ``Dataset`` built from the produced instances with
        ``loads_instance`` set as its transform. For a single instance, the
        first (only) produced item.
    """
    recipe_produce = _get_produce_with_cache(dataset_query, **kwargs)

    if isinstance(instance_or_instances, list):
        produced = recipe_produce(instance_or_instances)
        return Dataset.from_list(produced).with_transform(loads_instance)

    # A single instance is wrapped for processing and unwrapped again afterwards.
    return recipe_produce([instance_or_instances])[0]
|
|
|
|
|
def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    """Produce instances with a recipe, run an inference engine, post-process.

    Args:
        instance_or_instances: Instance(s) handed to ``produce()``.
        engine: The inference engine, or a reference ``fetch_artifact`` can resolve.
        dataset_query: Query string selecting the recipe.
        return_data: When True, return the dataset extended with prediction
            columns instead of only the predictions.
        return_log_probs: When True, call ``engine.infer_log_probs`` and
            JSON-encode each raw prediction.
        return_meta_data: Forwarded to the engine. When True, the engine
            outputs are objects whose ``prediction`` attribute holds the raw
            prediction.
        previous_messages: When given, ``previous_messages[i]`` is prepended
            (with ``+``) to the ``"source"`` of the i-th instance.
            NOTE(review): the ``+`` implies each entry has the same sequence
            type as ``source``; the annotation may be too narrow - confirm.
        **kwargs: Recipe arguments, used when no query is given.

    Returns:
        The post-processed predictions, or, when ``return_data`` is True, the
        dataset with added ``"prediction"`` and ``"raw_prediction"`` columns
        (plus ``"infer_meta_data"`` when ``return_meta_data`` is True).

    Raises:
        NotImplementedError: If ``return_log_probs`` is True but the engine is
            not a ``LogProbInferenceEngine``.
    """
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)

    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)

    # With meta data the engine returns wrapper objects; otherwise the raw
    # predictions themselves.
    raw_predictions = (
        [output.prediction for output in infer_outputs]
        if return_meta_data
        else infer_outputs
    )
    if return_log_probs:
        # Log-prob predictions are serialized to JSON strings.
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]

    predictions = post_process(raw_predictions, dataset)
    if not return_data:
        return predictions

    if return_meta_data:
        # Build filtered copies rather than deleting from each output's live
        # `__dict__`, which would strip the `prediction` attribute off the
        # engine's own output objects.
        infer_output_list = [
            {
                key: value
                for key, value in vars(infer_output).items()
                if key != "prediction"
            }
            for infer_output in infer_outputs
        ]
        dataset = dataset.add_column("infer_meta_data", infer_output_list)
    dataset = dataset.add_column("prediction", predictions)
    return dataset.add_column("raw_prediction", raw_predictions)
|
|
|
|
|
def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    """Produce instances with a recipe and let an engine select among options.

    Args:
        instance_or_instances: Instance(s) handed to ``produce()``.
        engine: The selecting engine, or a reference ``fetch_artifact`` can resolve.
        dataset_query: Query string selecting the recipe.
        return_data: When True, return the dataset with a ``"prediction"``
            column instead of only the predictions.
        previous_messages: When given, ``previous_messages[i]`` is prepended
            (with ``+``) to the ``"source"`` of the i-th instance.
        **kwargs: Recipe arguments, used when no query is given.

    Returns:
        The result of ``engine.select``, or the dataset extended with it as
        the ``"prediction"`` column when ``return_data`` is True.
    """
    dataset = produce(instance_or_instances, dataset_query, **kwargs)

    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)

    selecting_engine, _ = fetch_artifact(engine)
    chosen = selecting_engine.select(dataset)

    if not return_data:
        return chosen
    return dataset.add_column("prediction", chosen)
|
|