File size: 4,325 Bytes
7aa5a5e 058c80a 3c36ff5 6502654 3c36ff5 058c80a a4795aa 058c80a 3c36ff5 a4795aa 3c36ff5 a4795aa 3c36ff5 058c80a 3c36ff5 058c80a 3c36ff5 7aa5a5e 058c80a 7aa5a5e 058c80a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union
from datasets import DatasetDict
from .artifact import fetch_artifact
from .dataset_utils import get_dataset_artifact
from .logging_utils import get_logger
from .metric_utils import _compute, _post_process
from .operator import SourceOperator
from .standard import StandardRecipe
# Module-level logger obtained from the project's `logging_utils.get_logger` helper.
logger = get_logger()
def load(source: Union[SourceOperator, str]) -> DatasetDict:
    """Run a source operator and convert its output to a dataset.

    Args:
        source: Either a ``SourceOperator`` instance, or a string that is
            handed to ``fetch_artifact`` to obtain the operator.

    Returns:
        Whatever ``to_dataset()`` yields on the object produced by calling
        the operator.

    Raises:
        AssertionError: If ``source`` is neither a ``SourceOperator`` nor a
            string (only while assertions are enabled).
    """
    accepted_types = (SourceOperator, str)
    assert isinstance(
        source, accepted_types
    ), "source must be a SourceOperator or a string"

    operator = source
    if isinstance(operator, str):
        # fetch_artifact is unpacked as a pair; only its first item is used.
        operator, _ = fetch_artifact(operator)

    stream = operator()
    return stream.to_dataset()
def _load_dataset_from_query(dataset_query: str) -> DatasetDict:
    """Load a dataset described by a textual query.

    Args:
        dataset_query: Query string forwarded to ``get_dataset_artifact``.
            Every occurrence of ``"sys_prompt"`` is rewritten to
            ``"instruction"`` first (presumably a renamed parameter kept for
            backward compatibility -- TODO confirm).

    Returns:
        The result of calling the fetched artifact and invoking
        ``to_dataset()`` on what it returns.
    """
    normalized_query = dataset_query.replace("sys_prompt", "instruction")
    stream_factory = get_dataset_artifact(normalized_query)
    return stream_factory().to_dataset()
def _load_dataset_from_dict(dataset_params: Dict[str, Any]) -> DatasetDict:
    """Build a ``StandardRecipe`` from keyword parameters and load its dataset.

    Args:
        dataset_params: Mapping of recipe attribute names to values. Every
            key must appear in ``StandardRecipe``'s ``__fields__`` entry.

    Returns:
        The result of calling the recipe and invoking ``to_dataset()`` on
        what it returns.

    Raises:
        AssertionError: If a key is not a declared ``StandardRecipe`` field
            (only while assertions are enabled).
    """
    # The field registry is read from the class's own __dict__, not via getattr.
    declared_fields = StandardRecipe.__dict__["__fields__"]
    recipe_attributes = list(declared_fields.keys())

    # Check the names up front so an unknown one is reported together with
    # the full list of accepted attributes.
    for param in dataset_params:
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )

    return StandardRecipe(**dataset_params)().to_dataset()
def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
    """Load a dataset either from a query string or from explicit parameters.

    Exactly one of the two input styles must be used. When 'dataset_query' is
    given, the dataset is loaded from a card in the local catalog according to
    the parameters encoded in the query. Otherwise the keyword arguments are
    used to load the dataset from a provided card.

    Args:
        dataset_query (str, optional): A string query which specifies the
            dataset to load from the local catalog. For example:
            "card=cards.wnli,template=templates.classification.multi_class.relation.default".
        **kwargs: Arguments used to load a dataset from a provided card which
            is not present in the local catalog.

    Returns:
        DatasetDict

    Raises:
        ValueError: If both 'dataset_query' and keyword arguments are given,
            if neither is given, or if 'dataset_query' is not a string.

    Examples:
        dataset = load_dataset(
            dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
        )  # card must be present in local catalog

        card = TaskCard(...)
        template = Template(...)
        loader_limit = 10
        dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
    """
    # Truthiness, not "is None": an empty query string counts as "no query".
    has_query = bool(dataset_query)
    has_kwargs = bool(kwargs)

    if has_query and has_kwargs:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if not has_query and not has_kwargs:
        raise ValueError(
            "Either 'dataset_query' or key-worded arguments must be provided."
        )

    if has_kwargs:
        return _load_dataset_from_dict(kwargs)

    # Only a (truthy) query remains at this point; it has to be a string.
    if isinstance(dataset_query, str):
        return _load_dataset_from_query(dataset_query)

    raise ValueError(
        f"If specified, 'dataset_query' must be a string, however, "
        f"'{dataset_query}' was provided instead, which is of type "
        f"'{type(dataset_query)}'."
    )
def evaluate(predictions, data) -> List[Dict[str, Any]]:
    """Delegate to ``_compute`` from ``metric_utils``.

    Args:
        predictions: Forwarded as the ``predictions`` keyword argument.
        data: Forwarded as the ``references`` keyword argument.

    Returns:
        Whatever ``_compute`` returns, unchanged.
    """
    results = _compute(predictions=predictions, references=data)
    return results
def post_process(predictions, data) -> List[Dict[str, Any]]:
    """Delegate to ``_post_process`` from ``metric_utils``.

    Args:
        predictions: Forwarded as the ``predictions`` keyword argument.
        data: Forwarded as the ``references`` keyword argument.

    Returns:
        Whatever ``_post_process`` returns, unchanged.
    """
    processed = _post_process(predictions=predictions, references=data)
    return processed
@lru_cache
def _get_produce_with_cache(recipe_query):
    """Return the ``produce`` attribute of the artifact for ``recipe_query``.

    The lookup is memoized with ``functools.lru_cache``, so repeated calls
    with an equal query reuse the previously returned object instead of
    calling ``get_dataset_artifact`` again. ``recipe_query`` must therefore
    be hashable.
    """
    artifact = get_dataset_artifact(recipe_query)
    return artifact.produce
def produce(instance_or_instances, recipe_query):
    """Run one instance or a list of instances through a recipe's ``produce``.

    Args:
        instance_or_instances: A single instance, or a ``list`` of instances.
            Only an actual ``list`` is treated as a batch; any other value is
            treated as a single instance.
        recipe_query: Query used to look up the (cached) ``produce`` callable.

    Returns:
        For a list input, the callable's result as is. For a single instance,
        the first element of the result obtained for a one-item list.
    """
    produce_fn = _get_produce_with_cache(recipe_query)

    if isinstance(instance_or_instances, list):
        return produce_fn(instance_or_instances)

    # Single instance: wrap it for the call, then unwrap the lone result.
    return produce_fn([instance_or_instances])[0]
def infer(instance_or_instances, recipe, engine):
    """Produce inputs, run an inference engine on them, and post-process.

    Args:
        instance_or_instances: A single instance or a list of instances,
            passed to ``produce``.
        recipe: Recipe query passed to ``produce``.
        engine: Value handed to ``fetch_artifact`` to obtain the object whose
            ``infer`` method is called.

    Returns:
        The result of ``post_process`` applied to the engine's predictions
        and the produced data.
    """
    processed = produce(instance_or_instances, recipe)
    # fetch_artifact is unpacked as a pair; only its first item is used.
    inference_engine, _ = fetch_artifact(engine)
    raw_predictions = inference_engine.infer(processed)
    return post_process(raw_predictions, processed)
|