File size: 20,868 Bytes
9d5b4c0
c6e9c8c
fe70438
d443ad5
c6e9c8c
9d5b4c0
8320ba9
fe70438
1e05e68
 
2c69fb8
7cdc7d0
c6e9c8c
fe70438
7cdc7d0
fe70438
9d5b4c0
2c69fb8
cd9d84b
d08fbc6
7cdc7d0
d443ad5
fe70438
c6e9c8c
d08fbc6
fe70438
1e05e68
 
c6e9c8c
eee0bf8
 
 
 
 
 
9d5b4c0
d08fbc6
 
7cdc7d0
cd9d84b
fe70438
7cdc7d0
9d5b4c0
 
 
8320ba9
 
c6e9c8c
d08fbc6
 
eee0bf8
 
5bbb99c
 
 
 
eee0bf8
 
 
5bbb99c
c6e9c8c
9d5b4c0
cd9d84b
c6e9c8c
 
 
 
 
 
fe70438
eee0bf8
c6e9c8c
 
fe70438
 
 
1e05e68
 
 
9d5b4c0
 
 
 
 
 
5bbb99c
 
d08fbc6
 
 
 
 
 
 
 
 
 
 
9d5b4c0
5bbb99c
 
a024d9a
5bbb99c
9d5b4c0
5bbb99c
9d5b4c0
eee0bf8
 
 
1e05e68
eee0bf8
 
 
 
 
1e05e68
eee0bf8
 
 
 
 
 
1e05e68
eee0bf8
 
 
 
 
 
1e05e68
5bbb99c
f6ebc4f
 
 
 
 
 
 
 
 
 
5bbb99c
9245edf
 
 
 
9d5b4c0
 
fe70438
9d5b4c0
 
 
 
 
 
 
 
7cdc7d0
 
 
 
 
67f4e71
 
 
2c69fb8
67f4e71
 
 
2c69fb8
67f4e71
 
 
2c69fb8
67f4e71
9d5b4c0
 
f6ebc4f
9d5b4c0
f6ebc4f
 
2c69fb8
d08fbc6
 
0a1b314
d08fbc6
 
0a1b314
d08fbc6
 
 
 
 
 
0a1b314
9d5b4c0
 
2c69fb8
0a1b314
2c69fb8
c6e9c8c
2c69fb8
 
 
 
b462f85
9d5b4c0
2c69fb8
 
 
 
 
 
 
 
b462f85
2c69fb8
 
 
 
 
 
 
 
 
b462f85
2c69fb8
 
 
 
fe70438
2c69fb8
 
d08fbc6
 
2c69fb8
 
9d5b4c0
fe70438
 
 
 
 
2c69fb8
 
9d5b4c0
 
 
 
 
 
 
 
2c69fb8
 
 
cc5f321
 
 
 
 
 
2c69fb8
d08fbc6
2c69fb8
fe70438
 
 
d08fbc6
fe70438
 
 
 
 
 
d08fbc6
59be457
 
d08fbc6
 
2c69fb8
d08fbc6
2c69fb8
d08fbc6
 
 
 
 
 
 
 
 
 
 
 
2c69fb8
 
058c80a
a024d9a
9d5b4c0
 
a024d9a
2c69fb8
 
eee0bf8
d08fbc6
 
c6e9c8c
d08fbc6
c6e9c8c
d443ad5
fe70438
 
 
 
 
 
 
 
 
 
 
 
 
 
eee0bf8
9d5b4c0
2c69fb8
eee0bf8
c6e9c8c
 
 
cd9d84b
c6e9c8c
 
a350a45
9d5b4c0
1e05e68
 
 
 
 
 
 
c6e9c8c
67f4e71
5bbb99c
9d5b4c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e05e68
9d5b4c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e05e68
9d5b4c0
 
 
 
 
 
 
 
eee0bf8
9d5b4c0
 
c6e9c8c
9d5b4c0
 
 
c6e9c8c
9d5b4c0
 
 
 
fe70438
d08fbc6
 
7cdc7d0
 
d08fbc6
e6be0c8
 
5bbb99c
e6be0c8
 
 
 
 
eee0bf8
fe70438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6be0c8
eee0bf8
 
 
 
cd9d84b
eee0bf8
 
 
cd9d84b
eee0bf8
e6be0c8
 
5bbb99c
 
 
1e05e68
eee0bf8
 
5bbb99c
 
 
 
 
 
cd9d84b
eee0bf8
1e05e68
8320ba9
 
d08fbc6
5bbb99c
 
 
 
 
 
 
 
 
 
 
cd9d84b
f6ebc4f
5bbb99c
eee0bf8
d443ad5
 
5bbb99c
 
 
d443ad5
5bbb99c
 
cd9d84b
5bbb99c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
from typing import List, Optional, Union

from .artifact import fetch_artifact
from .augmentors import Augmentor, NullAugmentor
from .card import TaskCard
from .collections_operators import GetLength
from .dataclass import Field, InternalField, NonPositionalField, OptionalField
from .error_utils import UnitxtError
from .formats import Format, SystemFormat
from .logging_utils import get_logger
from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
from .operators import Set, StreamRefiner
from .recipe import Recipe
from .schema import FinalizeDataset
from .serializers import SingleTypeSerializer
from .settings_utils import get_constants, get_settings
from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
from .stream import MultiStream
from .system_prompts import EmptySystemPrompt, SystemPrompt
from .task import Task
from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template, TemplatesList
from .type_utils import isoftype
from .utils import LRUCache

# Module-level handles used throughout this module: ``constants`` (e.g.
# ``constants.inference_stream``), ``settings`` (e.g. ``settings.default_format``)
# and the module ``logger``.
constants = get_constants()
settings = get_settings()
logger = get_logger()


class CreateDemosPool(SeparateSplit):
    """Separate the demonstrations pool from a source split.

    Identical to ``SeparateSplit``; this subclass exists only so that the
    step carries a meaningful name inside the recipe pipeline.
    """


class BaseRecipe(Recipe, SourceSequentialOperator):
    """Build the full data-preparation pipeline of a recipe.

    The recipe's ``steps`` are the stage operators created in ``set_pipelines``
    (loading, metadata, standardization, processing, metadata again,
    verbalization, finalize). ``reset_pipeline`` populates those stages from the
    parameters below. ``produce`` reuses the same stages, wired into separate
    inference pipelines, to turn raw task instances into model-ready instances.
    """

    # Base parameters
    card: TaskCard = None
    # Defaults to ``card.task`` in ``reset_pipeline`` when not set.
    task: Task = None
    # A TemplatesList is converted to a plain list in ``prepare``; a list of
    # templates is applied with ``ApplyRandomTemplate``.
    template: Union[Template, List[Template], TemplatesList] = None
    system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
    # When None, falls back to ``settings.default_format`` or ``SystemFormat()``.
    format: Format = None
    # Extra serializers registered on the template(s) serializer in ``verify``.
    serializer: Union[SingleTypeSerializer, List[SingleTypeSerializer]] = None

    # Additional parameters
    template_card_index: int = NonPositionalField(default=None)
    metrics: List[str] = NonPositionalField(default=None)
    postprocessors: List[str] = NonPositionalField(default=None)

    group_by: List[Union[str, List[str]]] = []

    # Also enforced by an extra StreamRefiner in case the loader ignores it.
    loader_limit: int = None

    max_train_instances: int = None
    max_validation_instances: int = None
    max_test_instances: int = None

    train_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
    validation_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
    test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)

    demos_pool_size: int = None
    # An int samples a constant number of demos; a list samples the number
    # of demos from the list per instance.
    num_demos: Optional[Union[int, List[int]]] = 0
    demos_removed_from_data: bool = True

    demos_pool_name: str = "demos_pool"
    demos_taken_from: str = "train"
    demos_field: str = "demos"
    # Falls back to ``card.sampler`` when demos are used.
    sampler: Sampler = None

    augmentor: Union[Augmentor, List[Augmentor]] = OptionalField(default=None)

    steps: List[StreamingOperator] = InternalField(default_factory=list)

    # shared class cache (demos pools keyed by ``str(recipe)``)
    _demos_pool_cache = LRUCache(max_size=10)

    def before_process_multi_stream(self):
        super().before_process_multi_stream()

    @property
    def max_demos_size(self):
        """Largest number of demos that may be requested for one instance.

        Returns ``num_demos`` unchanged for an int (or None), the maximum of the
        list for a list, and 0 for an empty list (instead of crashing in ``max``).
        """
        if isinstance(self.num_demos, list):
            return max(self.num_demos, default=0)
        return self.num_demos

    def verify(self):
        """Validate the combination of recipe parameters.

        Raises:
            ValueError: when neither card nor task is set, when demos are
                requested without a card, when the demos pool is missing or too
                small, when a ``max_*`` limit or ``demos_pool_size`` exceeds
                ``loader_limit``, or when metrics/postprocessors/format/template
                have the wrong type.
        """
        super().verify()

        if self.task is None and self.card is None:
            raise ValueError("Set card or task in the recipe")

        # ``use_demos`` handles int, list and None values of ``num_demos``;
        # comparing ``num_demos > 0`` directly raised TypeError for lists/None.
        if self.card is None and (
            self.use_demos or self.demos_pool_size is not None
        ):
            raise ValueError(
                "To use num_demos and demos_pool_size in recipe set a card."
            )

        if self.use_demos:
            if self.demos_pool_size is None or self.demos_pool_size < 1:
                raise ValueError(
                    "When using demonstrations both num_demos and demos_pool_size should be assigned with positive integers."
                )
            if self.demos_pool_size < self.max_demos_size:
                raise ValueError(
                    f"num_demos (got: {self.max_demos_size}) should not exceed demos_pool_size (got: {self.demos_pool_size})"
                )
            if self.loader_limit and self.demos_pool_size > self.loader_limit:
                raise ValueError(
                    f"demos_pool_size should not exceed loader_limit ({self.loader_limit}), Got demos_pool_size={self.demos_pool_size}"
                )

        if self.loader_limit:
            if self.max_test_instances and self.max_test_instances > self.loader_limit:
                raise ValueError(
                    f"max_test_instances should not exceed loader_limit ({self.loader_limit}), Got max_test_instances={self.max_test_instances}"
                )
            if (
                self.max_validation_instances
                and self.max_validation_instances > self.loader_limit
            ):
                raise ValueError(
                    f"max_validation_instances should not exceed loader_limit ({self.loader_limit}), Got max_validation_instances={self.max_validation_instances}"
                )
            if (
                self.max_train_instances
                and self.max_train_instances > self.loader_limit
            ):
                raise ValueError(
                    f"max_train_instances should not exceed loader_limit ({self.loader_limit}), Got max_train_instances={self.max_train_instances}"
                )
        if self.metrics is not None and not isinstance(self.metrics, list):
            raise ValueError(
                f"metrics must be a list of metrics.  Got metrics = {self.metrics}"
            )
        if self.postprocessors is not None and not isinstance(
            self.postprocessors, list
        ):
            raise ValueError(
                f"post processors must be a list of post processor.  Got postprocessors = {self.postprocessors}"
            )

        if self.format is not None and not isinstance(self.format, Format):
            raise ValueError(
                f"format parameter must be an instance of a class derived from Format.  Got format = {self.format}"
            )
        if self.template is None:
            raise ValueError(
                "You must set in the recipe either `template`, `template_card_index`."
            )

        templates = self.template if isinstance(self.template, list) else [self.template]
        for template in templates:
            self.verify_template(template)

        if self.serializer is not None:
            if not isinstance(self.serializer, list):
                self.serializer = [self.serializer]
            # ``template`` may be a list (``prepare`` converts a TemplatesList to
            # one), so register the serializers on every template.
            for template in templates:
                template.serializer.add_serializers(self.serializer)

    def prepare_refiners(self):
        """Configure the per-split refiners with their ``max_*_instances`` limits.

        Each refiner is restricted to its own stream and appended to the
        processing stage.
        """
        self.train_refiner.max_instances = self.max_train_instances
        self.train_refiner.apply_to_streams = ["train"]
        self.processing.steps.append(self.train_refiner)

        self.validation_refiner.max_instances = self.max_validation_instances
        self.validation_refiner.apply_to_streams = ["validation"]
        self.processing.steps.append(self.validation_refiner)

        self.test_refiner.max_instances = self.max_test_instances
        self.test_refiner.apply_to_streams = ["test"]
        self.processing.steps.append(self.test_refiner)

    def verify_template(self, template):
        """Raise ValueError unless ``template`` is a ``Template`` instance."""
        if not isinstance(template, Template):
            raise ValueError(
                f"template argument must be an object of type Template. Got template = {template}"
            )

    def set_pipelines(self):
        """Create the empty stage operators and wire them into the pipelines.

        ``steps`` is the full dataset pipeline. ``inference_instance``,
        ``inference_demos`` and ``inference`` reuse the same stage objects for
        ``produce``. ``metadata`` runs twice in each list; presumably so the
        recipe metadata is (re)set after processing as well. TODO confirm.
        """
        self.loading = SequentialOperator(
            __description__="Loading the data from the data source."
        )
        self.metadata = SequentialOperator(
            __description__="Adding metadata (e.g. format, system prompt, template)  "
        )
        self.standardization = SequentialOperator(
            __description__="Standardizing the raw dataset fields to task field definition."
        )

        self.processing = SequentialOperator(
            __description__="Setting task fields (and selecting demos per sample if needed)."
        )
        self.verbalization = SequentialOperator()
        self.verbalization.__description__ = "Verbalizing the input to the model and gold references to the 'source', 'target' and 'references' fields."
        self.finalize = SequentialOperator()
        self.finalize.__description__ = "Adding post processors. Removing intermediate fields. Creating the final output dataset."

        self.steps = [
            self.loading,
            self.metadata,
            self.standardization,
            self.processing,
            self.metadata,
            self.verbalization,
            self.finalize,
        ]

        self.inference_instance = SequentialOperator()

        self.inference_instance.steps = [
            self.metadata,
            self.processing,
            self.metadata,
        ]

        self.inference_demos = SourceSequentialOperator()

        self.inference_demos.steps = [
            self.loading,
            self.metadata,
            self.standardization,
            self.processing,
            self.metadata,
        ]

        self.inference = SequentialOperator()

        self.inference.steps = [self.metadata, self.verbalization, self.finalize]

    def production_preprocess(self, task_instances):
        """Run metadata + processing over raw task instances and return them as a list."""
        ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
        return list(self.inference_instance(ms)[constants.inference_stream])

    def production_demos_pool(self):
        """Return the demos pool for inference, or ``[]`` when demos are not used.

        The pool is computed once with ``inference_demos`` and then stored in a
        class-level LRU cache keyed by ``str(self)``.
        """
        if self.use_demos:
            demos_pool = self.__class__._demos_pool_cache.get(str(self), None)
            if demos_pool is None:
                demos_pool = list(self.inference_demos()[self.demos_pool_name])
                self.__class__._demos_pool_cache[str(self)] = demos_pool
            return demos_pool
        return []

    @property
    def has_custom_demos_pool(self):
        """True when a positive ``demos_pool_size`` was requested."""
        return self.demos_pool_size is not None and self.demos_pool_size > 0

    @property
    def use_demos(self):
        """True when at least one demonstration may be sampled per instance."""
        return self.num_demos is not None and self.max_demos_size > 0

    def produce(self, task_instances):
        """Use the recipe in production to produce model ready query from standard task instance."""
        self.before_process_multi_stream()
        streams = {
            constants.inference_stream: self.production_preprocess(task_instances),
        }
        if self.use_demos:
            streams[self.demos_pool_name] = self.production_demos_pool()
        multi_stream = MultiStream.from_iterables(streams)
        multi_stream = self.inference(multi_stream)
        return list(multi_stream[constants.inference_stream])

    def reset(self):
        """Rebuild the pipeline from the current parameters."""
        self.reset_pipeline()

    def reset_pipeline(self):
        """(Re)build every pipeline stage from the recipe parameters.

        Raises:
            ValueError: when neither card nor task is set, when demos are
                requested without a card or without an available sampler, or
                when ``num_demos`` is neither an int nor a list.
            UnitxtError: when an augmentor is set but the task declares no
                ``augmentable_inputs``.
        """
        if self.format is None:
            if settings.default_format is not None:
                self.format, _ = fetch_artifact(settings.default_format)
            else:
                self.format = SystemFormat()

        if self.card and self.card.preprocess_steps is None:
            self.card.preprocess_steps = []

        if self.task is None:
            # Called from ``prepare``, which may run before ``verify``: fail with
            # a clear message instead of an AttributeError on ``None.task``.
            if self.card is None:
                raise ValueError("Set card or task in the recipe")
            self.task = self.card.task

        self.set_pipelines()

        if self.card is not None:
            loader = self.card.loader
            if self.loader_limit:
                loader.loader_limit = self.loader_limit
                logger.info(f"Loader line limit was set to  {self.loader_limit}")
            self.loading.steps.append(loader)

            # This is required in case loader_limit is not enforced by the loader
            if self.loader_limit:
                self.loading.steps.append(
                    StreamRefiner(max_instances=self.loader_limit)
                )

        self.metadata.steps.append(
            Set(
                fields={
                    "recipe_metadata/system_prompt": self.system_prompt,
                    "recipe_metadata/format": self.format,
                }
            )
        )

        if self.card:
            self.standardization.steps.extend(self.card.preprocess_steps)

        self.processing.steps.append(self.task)

        if self.augmentor is not None and not isoftype(self.augmentor, NullAugmentor):
            # Use ``self.task`` (already defaulted to ``card.task``) for both the
            # check and the fields, so task-only recipes work and stay consistent.
            if (
                self.task.augmentable_inputs is None
                or len(self.task.augmentable_inputs) == 0
            ):
                raise UnitxtError(
                    f"You specified augmentor in the recipe but the got task without augmentable_inputs: {self.task}"
                )

            if not isinstance(self.augmentor, list):
                self.augmentor = [self.augmentor]

            for augmentor in self.augmentor:
                augmentor.set_fields(self.task.augmentable_inputs)
                self.processing.steps.append(augmentor)

        if self.has_custom_demos_pool:
            self.processing.steps.append(
                CreateDemosPool(
                    from_split=self.demos_taken_from,
                    to_split_names=[self.demos_pool_name, self.demos_taken_from],
                    to_split_sizes=[int(self.demos_pool_size)],
                    remove_targets_from_source_split=self.demos_removed_from_data,
                )
            )

        if self.use_demos and self.sampler is None:
            if self.card is None:
                raise ValueError(
                    "To use num_demos and demos_pool_size in recipe set a card."
                )
            if self.card.sampler is None:
                raise ValueError(
                    "Unexpected None value for card.sampler. "
                    "To use num_demos > 0, please set a sampler on the TaskCard."
                )
            self.sampler = self.card.sampler

        self.prepare_refiners()

        if self.use_demos:
            if isinstance(self.num_demos, int):
                self.verbalization.steps.append(
                    ConstantSizeSample(
                        from_stream=self.demos_pool_name,
                        to_field=self.demos_field,
                        sampler=self.sampler,
                        sample_size=self.num_demos,
                    )
                )
                self.verbalization.steps.append(
                    Set(fields={"recipe_metadata/num_demos": self.num_demos})
                )
            elif isinstance(self.num_demos, list):
                self.verbalization.steps.append(
                    RandomSizeSample(
                        from_stream=self.demos_pool_name,
                        to_field=self.demos_field,
                        sampler=self.sampler,
                        sample_sizes=self.num_demos,
                    )
                )
                # Count the demos actually sampled into the configured field
                # (previously hard-coded to "demos").
                self.verbalization.steps.append(
                    GetLength(
                        field=self.demos_field, to_field="recipe_metadata/num_demos"
                    )
                )
            else:
                raise ValueError("num_demos must be int or List[int]")
            template_kwargs = {"demos_field": self.demos_field}
        else:
            self.verbalization.steps.append(
                Set(fields={"recipe_metadata/num_demos": 0})
            )
            template_kwargs = {}

        if isinstance(self.template, list):
            self.verbalization.steps.append(
                ApplyRandomTemplate(templates=self.template, **template_kwargs)
            )
        else:
            self.verbalization.steps.append(
                ApplySingleTemplate(template=self.template, **template_kwargs)
            )

        self.verbalization.steps.append(self.system_prompt)
        self.verbalization.steps.append(self.format)

        if self.postprocessors is not None:
            self.finalize.steps.append(
                Set(fields={"postprocessors": self.postprocessors})
            )

        if self.metrics is not None:
            self.finalize.steps.append(Set(fields={"metrics": self.metrics}))

        self.finalize.steps.append(FinalizeDataset(group_by=self.group_by))

    def prepare(self):
        """Normalize a TemplatesList into a plain list and build the pipeline."""
        if isinstance(self.template, TemplatesList):
            self.template = self.template.items
        self.reset_pipeline()


class StandardRecipeWithIndexes(BaseRecipe):
    """Recipe that can select its template from the card by index or key.

    ``template_card_index`` is used to index ``card.templates`` (a list index
    or a dict key). When neither ``template`` nor ``template_card_index`` is
    given, the first template of the card is used.
    """

    template_card_index: int = None

    def prepare(self):
        """Resolve ``self.template`` from the card when needed, then build the pipeline.

        Raises:
            AssertionError: if both ``template`` and ``template_card_index`` are set.
            ValueError: if there is no template, no index and no card; if an index
                is given without a card; or if the index is not defined in the card.
        """
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; the exception type stays AssertionError for callers.
        if self.template_card_index is not None and self.template is not None:
            raise AssertionError(
                f"Specify either template ({self.template}) or template_card_index ({self.template_card_index}) but not both"
            )

        if self.template_card_index is None and self.template is None:
            if self.card is None:
                raise ValueError(
                    "Specify a template or template_card_index, or a card to get a default template from."
                )
            self.template_card_index = (
                0
                if isinstance(self.card.templates, list)
                else next(iter(self.card.templates.keys()))
            )
            logger.warning(
                "Template was not specified in recipe, using the first template from the card by default."
            )

        if self.template_card_index is not None:
            if self.card is None:
                # Previously surfaced as an AttributeError on ``None.templates``.
                raise ValueError(
                    f"template_card_index ({self.template_card_index}) requires a card to select the template from."
                )
            try:
                self.template = self.card.templates[self.template_card_index]
            except (LookupError, TypeError) as e:
                # LookupError covers missing dict keys (KeyError) and out-of-range
                # list indexes (IndexError); TypeError covers a wrong index type.
                if isinstance(self.card.templates, dict):
                    options = list(self.card.templates.keys())
                else:
                    options = list(range(0, len(self.card.templates)))
                raise ValueError(
                    f"card_template_index '{self.template_card_index}' is not defined in card. Possible card_template_index options: {options}"
                ) from e

        super().prepare()


class StandardRecipe(StandardRecipeWithIndexes):
    """A standard recipe for loading, processing and verbalizing a dataset.

    Builds the complete pipeline (loading, standardization, processing, demo
    sampling, template application, system prompt, format and finalization)
    from a card (or a task) and one or more templates. All parameters are set
    as fields of the recipe.

    Attributes:
        card (TaskCard): TaskCard object associated with the recipe.
        task (Task, optional): Task to use. Defaults to ``card.task`` when not set.
        template (Template, List[Template] or TemplatesList, optional): Template(s) to be used for the recipe. When several templates are given, they are applied with ``ApplyRandomTemplate``.
        system_prompt (SystemPrompt, optional): SystemPrompt object to be used for the recipe.
        format (Format, optional): Format to be used for the recipe. Defaults to ``settings.default_format`` if set, otherwise ``SystemFormat()``.
        serializer (SingleTypeSerializer or List[SingleTypeSerializer], optional): Serializers added to the template's serializer.
        loader_limit (int, optional): Specifies the maximum number of instances per stream to be returned from the loader (used to reduce loading time in large datasets).
        metrics (List[str]): List of catalog metrics to use with this recipe.
        postprocessors (List[str]): List of catalog processors to apply at post processing. (Not recommended to use from here)
        group_by (List[Union[str, List[str]]]): List of task_data or metadata keys to group global scores by.
        train_refiner (StreamRefiner, optional): Refiner applied to the "train" stream.
        max_train_instances (int, optional): Maximum number of instances kept in the "train" stream.
        validation_refiner (StreamRefiner, optional): Refiner applied to the "validation" stream.
        max_validation_instances (int, optional): Maximum number of instances kept in the "validation" stream.
        test_refiner (StreamRefiner, optional): Refiner applied to the "test" stream.
        max_test_instances (int, optional): Maximum number of instances kept in the "test" stream.
        demos_pool_size (int, optional): Number of instances taken from ``demos_taken_from`` to form the demos pool.
        num_demos (int or List[int], optional): Number of demos per instance. When a list is given, the number of demos is sampled from it (``RandomSizeSample``). Default is 0.
        demos_pool_name (str, optional): Name of the demos pool. Default is "demos_pool".
        demos_taken_from (str, optional): Split from which the demos are taken. Default is "train".
        demos_field (str, optional): Field name for demos. Default is "demos".
        demos_removed_from_data (bool, optional): Whether to remove the demos from the source data. Default is True.
        sampler (Sampler, optional): The Sampler used to select the demonstrations when num_demos > 0. Defaults to ``card.sampler``.
        steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe.
        augmentor (Augmentor or List[Augmentor], optional): Augmentor(s) used to pseudo-randomly augment the task's augmentable inputs.
        template_card_index (int, optional): Index (or key, for a dict of templates) of the card template to be used for preparing the recipe.

    Methods:
        prepare(): Resolves the template and builds the recipe pipeline,
            arranging all the steps, refiners and renderers in sequence.

    Raises:
        AssertionError: If both template and template_card_index are specified at the same time.
        ValueError: If no template can be resolved (no template, no template_card_index and no card) or if template_card_index is not defined in the card.
    """

    pass