import inspect
import json
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from .artifact import fetch_artifact
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
InferenceEngine,
LogProbInferenceEngine,
OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import UNITXT_DATASET_SCHEMA, loads_instance
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task
logger = get_logger()
constants = get_constants()
settings = get_settings()
def load(source: Union[SourceOperator, str]):
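    """Loads a SourceOperator (or the catalog name of one) and materializes it as a HuggingFace Dataset.

    The example below is an illustrative sketch: 'recipes.my_recipe' is a hypothetical catalog
    entry assumed to resolve to a SourceOperator.

    :Example:

        .. code-block:: python

            dataset = load("recipes.my_recipe")
    """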
assert isinstance(
source, (SourceOperator, str)
), "source must be a SourceOperator or a string"
if isinstance(source, str):
source, _ = fetch_artifact(source)
return source().to_dataset()
def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
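    """Resolves a dataset query string into a DatasetRecipe, falling back to the dataset artifact lookup."""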
dataset_query = dataset_query.replace("sys_prompt", "instruction")
try:
dataset_stream, _ = fetch_artifact(dataset_query)
    except Exception:
        # Fall back to resolving the query as a dataset artifact.
        dataset_stream = get_dataset_artifact(dataset_query)
return dataset_stream
def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
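    """Builds a DatasetRecipe from keyword parameters, verifying each against the DatasetRecipe fields."""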
recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
for param in dataset_params.keys():
assert param in recipe_attributes, (
f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
)
return DatasetRecipe(**dataset_params)
def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
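    """Validates that exactly one of 'dataset_query' or keyword arguments is provided."""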
if dataset_query and dataset_args:
raise ValueError(
"Cannot provide 'dataset_query' and key-worded arguments at the same time. "
"If you want to load dataset from a card in local catalog, use query only. "
"Otherwise, use key-worded arguments only to specify properties of dataset."
)
if dataset_query:
if not isinstance(dataset_query, str):
raise ValueError(
f"If specified, 'dataset_query' must be a string, however, "
f"'{dataset_query}' was provided instead, which is of type "
f"'{type(dataset_query)}'."
)
if not dataset_query and not dataset_args:
raise ValueError(
"Either 'dataset_query' or key-worded arguments must be provided."
)
def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
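    """Creates a DatasetRecipe from a catalog query string or from explicit keyword arguments.

    Exactly one of 'dataset_query' or keyword arguments must be provided; a DatasetRecipe
    passed as 'dataset_query' is returned unchanged.

    The example below is an illustrative sketch; the card and template names are assumed to
    exist in the local catalog.

    :Example:

        .. code-block:: python

            recipe = load_recipe("card=cards.wnli,template=templates.classification.multi_class.relation.default")
            # or, equivalently, with keyword arguments:
            recipe = load_recipe(card="cards.wnli", template="templates.classification.multi_class.relation.default")
    """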
if isinstance(dataset_query, DatasetRecipe):
return dataset_query
_verify_dataset_args(dataset_query, kwargs)
if dataset_query:
recipe = _get_recipe_from_query(dataset_query)
if kwargs:
recipe = _get_recipe_from_dict(kwargs)
return recipe
def create_dataset(
task: Union[str, Task],
test_set: List[Dict[Any, Any]],
train_set: Optional[List[Dict[Any, Any]]] = None,
validation_set: Optional[List[Dict[Any, Any]]] = None,
split: Optional[str] = None,
**kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
"""Creates dataset from input data based on a specific task.
Args:
task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
test_set : required list of instances
train_set : optional train_set
validation_set: optional validation set
split: optional one split to choose
**kwargs: Arguments used to load dataset from provided datasets (see load_dataset())
Returns:
DatasetDict
Example:
template = Template(...)
dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
"""
data = {"test": test_set}
if train_set is not None:
data["train"] = train_set
if validation_set is not None:
data["validation"] = validation_set
task, _ = fetch_artifact(task)
if "template" not in kwargs and task.default_template is None:
raise Exception(
f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
)
card = TaskCard(loader=LoadFromDictionary(data=data), task=task)
return load_dataset(card=card, split=split, **kwargs)
def load_dataset(
dataset_query: Optional[str] = None,
split: Optional[str] = None,
streaming: bool = False,
disable_cache: Optional[bool] = None,
**kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
"""Loads dataset.
If the 'dataset_query' argument is provided, then dataset is loaded from a card
in local catalog based on parameters specified in the query.
Alternatively, dataset is loaded from a provided card based on explicitly
given parameters.
Args:
dataset_query (str, optional):
A string query which specifies a dataset to load from
local catalog or name of specific recipe or benchmark in the catalog. For
example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
streaming (bool, False):
When True yields the data as Unitxt streams dictionary
split (str, optional):
The split of the data to load
disable_cache (str, optional):
Disable caching process of the data
**kwargs:
Arguments used to load dataset from provided card, which is not present in local catalog.
Returns:
DatasetDict
:Example:
.. code-block:: python
dataset = load_dataset(
dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
) # card and template must be present in local catalog
# or built programmatically
card = TaskCard(...)
template = Template(...)
loader_limit = 10
dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
"""
recipe = load_recipe(dataset_query, **kwargs)
stream = recipe()
if split is not None:
stream = stream[split]
if disable_cache is None:
disable_cache = settings.disable_hf_datasets_cache
if streaming:
dataset = stream.to_iterable_dataset(
features=UNITXT_DATASET_SCHEMA,
).map(loads_instance, batched=True)
else:
dataset = stream.to_dataset(
features=UNITXT_DATASET_SCHEMA, disable_cache=disable_cache
).with_transform(loads_instance)
frame = inspect.currentframe()
args, _, _, values = inspect.getargvalues(frame)
all_kwargs = {key: values[key] for key in args if key != "kwargs"}
all_kwargs.update(kwargs)
metadata = fill_metadata(**all_kwargs)
if isinstance(dataset, dict):
for ds in dataset.values():
ds.info.description = metadata.copy()
else:
dataset.info.description = metadata
return dataset
def fill_metadata(**kwargs):
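    """Returns a copy of the given kwargs enriched with the unitxt version and the creation time."""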
metadata = kwargs.copy()
metadata["unitxt_version"] = get_constants().version
metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
return metadata
def evaluate(
    predictions, dataset: Optional[Union[Dataset, IterableDataset]] = None, data=None
) -> EvaluationResults:
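    """Computes evaluation metrics for the given predictions against the given dataset.

    'dataset' should be the dataset returned by load_dataset()/produce() for which the
    predictions were generated; 'data' is accepted as a backward-compatible alias for 'dataset'.

    The example below is an illustrative sketch; it assumes 'dataset' was loaded with
    load_dataset() and 'predictions' is a list with one prediction per instance.

    :Example:

        .. code-block:: python

            results = evaluate(predictions=predictions, dataset=dataset)
            # 'results' carries per-instance and global scores plus metadata.
    """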
if dataset is None and data is None:
raise UnitxtError(message="Specify 'dataset' in evaluate")
if data is not None:
dataset = data # for backward compatibility
evaluation_result = _compute(predictions=predictions, references=dataset)
if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
evaluation_result.metadata["dataset"] = dataset.info.description
if hasattr(predictions, "metadata"):
evaluation_result.metadata["predictions"] = predictions.metadata
evaluation_result.metadata["creation_time"] = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S.%f"
)[:-3]
return evaluation_result
def post_process(predictions, data) -> List[Dict[str, Any]]:
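    """Applies the recipe's post-processing to raw predictions against the given dataset instances."""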
return _inference_post_process(predictions=predictions, references=data)
@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
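    """Caches and returns the 'produce' method of the recipe resolved from the given query/arguments."""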
return load_recipe(dataset_query, **kwargs).produce
def produce(
instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
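    """Processes raw instances through the recipe specified by the query/arguments.

    A single instance returns a single processed instance; a list of instances returns
    a Dataset of processed instances.

    The example below is an illustrative sketch: 'card=cards.my_card' is a hypothetical
    catalog query, and the instance fields must match that card's task.

    :Example:

        .. code-block:: python

            processed = produce({"field": "value"}, "card=cards.my_card")
    """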
is_list = isinstance(instance_or_instances, list)
if not is_list:
instance_or_instances = [instance_or_instances]
result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
if not is_list:
return result[0]
return Dataset.from_list(result).with_transform(loads_instance)
def infer(
instance_or_instances,
engine: InferenceEngine,
dataset_query: Optional[str] = None,
return_data: bool = False,
return_log_probs: bool = False,
return_meta_data: bool = False,
previous_messages: Optional[List[Dict[str, str]]] = None,
**kwargs,
):
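    """Produces the dataset from the given instances and runs inference with the given engine.

    Returns post-processed predictions; with return_data=True, the predictions (and raw
    predictions, plus inference metadata when return_meta_data=True) are attached as columns
    of the returned dataset instead.

    The example below is an illustrative sketch: 'engine' stands for any InferenceEngine
    instance or catalog reference, and 'card=cards.my_card' is a hypothetical catalog query.

    :Example:

        .. code-block:: python

            predictions = infer(
                [{"field": "value"}],
                engine=engine,
                dataset_query="card=cards.my_card",
            )
    """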
dataset = produce(instance_or_instances, dataset_query, **kwargs)
if previous_messages is not None:
def add_previous_messages(example, index):
example["source"] = previous_messages[index] + example["source"]
return example
dataset = dataset.map(add_previous_messages, with_indices=True)
engine, _ = fetch_artifact(engine)
if return_log_probs:
if not isinstance(engine, LogProbInferenceEngine):
raise NotImplementedError(
f"Error in infer: return_log_probs set to True but supplied engine "
f"{engine.__class__.__name__} does not support logprobs."
)
infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
raw_predictions = (
[output.prediction for output in infer_outputs]
if return_meta_data
else infer_outputs
)
raw_predictions = [
json.dumps(raw_prediction) for raw_prediction in raw_predictions
]
else:
infer_outputs = engine.infer(dataset, return_meta_data)
raw_predictions = (
[output.prediction for output in infer_outputs]
if return_meta_data
else infer_outputs
)
predictions = post_process(raw_predictions, dataset)
if return_data:
if return_meta_data:
infer_output_list = [
infer_output.__dict__ for infer_output in infer_outputs
]
for infer_output in infer_output_list:
del infer_output["prediction"]
dataset = dataset.add_column("infer_meta_data", infer_output_list)
dataset = dataset.add_column("prediction", predictions)
return dataset.add_column("raw_prediction", raw_predictions)
return predictions
def select(
instance_or_instances,
engine: OptionSelectingByLogProbsInferenceEngine,
dataset_query: Optional[str] = None,
return_data: bool = False,
previous_messages: Optional[List[Dict[str, str]]] = None,
**kwargs,
):
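    """Produces the dataset from the given instances and selects, for each instance, the best
    option according to the log-probabilities computed by the given engine.

    The example below is an illustrative sketch: 'engine' stands for any
    OptionSelectingByLogProbsInferenceEngine instance, and 'card=cards.my_card' is a
    hypothetical catalog query.

    :Example:

        .. code-block:: python

            selections = select([{"field": "value"}], engine=engine, dataset_query="card=cards.my_card")
    """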
dataset = produce(instance_or_instances, dataset_query, **kwargs)
if previous_messages is not None:
def add_previous_messages(example, index):
example["source"] = previous_messages[index] + example["source"]
return example
dataset = dataset.map(add_previous_messages, with_indices=True)
engine, _ = fetch_artifact(engine)
predictions = engine.select(dataset)
# predictions = post_process(raw_predictions, dataset)
if return_data:
return dataset.add_column("prediction", predictions)
return predictions