File size: 4,325 Bytes
7aa5a5e
058c80a
3c36ff5
 
 
 
6502654
3c36ff5
058c80a
a4795aa
058c80a
3c36ff5
 
 
 
a4795aa
3c36ff5
a4795aa
 
3c36ff5
 
 
 
 
058c80a
3c36ff5
 
 
 
 
058c80a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c36ff5
 
7aa5a5e
 
058c80a
 
 
 
7aa5a5e
 
 
 
 
 
 
 
 
 
 
 
 
058c80a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import DatasetDict

from .artifact import fetch_artifact
from .dataset_utils import get_dataset_artifact
from .logging_utils import get_logger
from .metric_utils import _compute, _post_process
from .operator import SourceOperator
from .standard import StandardRecipe

logger = get_logger()


def load(source: Union[SourceOperator, str]) -> DatasetDict:
    """Build a DatasetDict from a source operator or an artifact name.

    Args:
        source: A ready SourceOperator, or a string that is resolved via
            ``fetch_artifact`` into one.

    Returns:
        The DatasetDict produced by running the resolved operator.
    """
    assert isinstance(
        source, (SourceOperator, str)
    ), "source must be a SourceOperator or a string"
    if isinstance(source, str):
        operator, _ = fetch_artifact(source)
    else:
        operator = source
    return operator().to_dataset()


def _load_dataset_from_query(dataset_query: str) -> DatasetDict:
    """Resolve a query string into a DatasetDict via the local catalog.

    The legacy key name ``sys_prompt`` is rewritten to ``instruction``
    before the query is handed to ``get_dataset_artifact``.
    """
    normalized_query = dataset_query.replace("sys_prompt", "instruction")
    return get_dataset_artifact(normalized_query)().to_dataset()


def _load_dataset_from_dict(dataset_params: Dict[str, Any]) -> DatasetDict:
    """Build a DatasetDict from explicit StandardRecipe parameters.

    Every key in ``dataset_params`` must be a declared field of
    ``StandardRecipe``; an AssertionError naming the offending key and the
    valid field names is raised otherwise.
    """
    # Pydantic-style field registry of the recipe class.
    recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
    for name in dataset_params:
        assert name in recipe_attributes, (
            f"The parameter '{name}' is not an attribute of the 'StandardRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return StandardRecipe(**dataset_params)().to_dataset()


def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
    """Loads dataset.

    Exactly one loading mode must be used: either a 'dataset_query' string
    that selects a card from the local catalog, or key-worded arguments that
    describe the dataset explicitly (e.g. a card object not present in the
    catalog). Supplying both, or neither, raises a ValueError.

    Args:
        dataset_query (str, optional): A string query which specifies dataset to load from local catalog.
            For example:
            "card=cards.wnli,template=templates.classification.multi_class.relation.default".
        **kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.

    Returns:
        DatasetDict

    Examples:
        dataset = load_dataset(
            dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
        )  # card must be present in local catalog

        card = TaskCard(...)
        template = Template(...)
        loader_limit = 10
        dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
    """
    # Guard: the two loading modes are mutually exclusive.
    if dataset_query and kwargs:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )
        return _load_dataset_from_query(dataset_query)

    if not kwargs:
        raise ValueError("Either 'dataset_query' or key-worded arguments must be provided.")

    return _load_dataset_from_dict(kwargs)


def evaluate(predictions, data) -> List[Dict[str, Any]]:
    """Compute metric scores for predictions against the references in data."""
    scored = _compute(predictions=predictions, references=data)
    return scored


def post_process(predictions, data) -> List[Dict[str, Any]]:
    """Apply the dataset's post-processing pipeline to raw predictions."""
    processed = _post_process(predictions=predictions, references=data)
    return processed


@lru_cache
def _get_produce_with_cache(recipe_query):
    """Return the ``produce`` callable for a recipe query, memoized per query string."""
    artifact = get_dataset_artifact(recipe_query)
    return artifact.produce


def produce(instance_or_instances, recipe_query):
    """Run the recipe's produce step over one instance or a list of instances.

    A single (non-list) instance is wrapped in a one-element list before
    processing, and the single processed result is returned unwrapped.
    """
    instances = instance_or_instances
    wrapped = not isinstance(instances, list)
    if wrapped:
        instances = [instances]
    processed = _get_produce_with_cache(recipe_query)(instances)
    return processed[0] if wrapped else processed


def infer(instance_or_instances, recipe, engine):
    """Produce a dataset from the recipe, run the engine on it, and post-process.

    ``engine`` is resolved via ``fetch_artifact`` to an object exposing
    ``infer``; the raw predictions are then post-processed against the
    produced dataset.
    """
    dataset = produce(instance_or_instances, recipe)
    inference_engine, _ = fetch_artifact(engine)
    raw_predictions = inference_engine.infer(dataset)
    return post_process(raw_predictions, dataset)