File size: 10,791 Bytes
c6e9c8c 1e05e68 eee0bf8 1e05e68 c6e9c8c 1e05e68 c6e9c8c eee0bf8 c6e9c8c e6be0c8 1e05e68 c6e9c8c eee0bf8 5bbb99c eee0bf8 5bbb99c c6e9c8c 5bbb99c c6e9c8c eee0bf8 c6e9c8c 1e05e68 5bbb99c 1e05e68 eee0bf8 1e05e68 eee0bf8 1e05e68 eee0bf8 1e05e68 eee0bf8 1e05e68 5bbb99c 67f4e71 c6e9c8c eee0bf8 1e05e68 eee0bf8 c6e9c8c eee0bf8 c6e9c8c eee0bf8 c6e9c8c 5bbb99c 1e05e68 c6e9c8c 1e05e68 c6e9c8c 67f4e71 5bbb99c 1e05e68 eee0bf8 1e05e68 c6e9c8c e6be0c8 5bbb99c e6be0c8 eee0bf8 e6be0c8 eee0bf8 e6be0c8 5bbb99c 1e05e68 eee0bf8 5bbb99c eee0bf8 1e05e68 5bbb99c eee0bf8 5bbb99c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 |
from typing import List
from .card import TaskCard
from .dataclass import Field, InternalField, OptionalField
from .formats import Format, SystemFormat
from .instructions import EmptyInstruction, Instruction
from .logging_utils import get_logger
from .operator import SourceSequentialOperator, StreamingOperator
from .operators import (
Augmentor,
NullAugmentor,
StreamRefiner,
)
from .recipe import Recipe
from .schema import ToUnitxtGroup
from .splitters import Sampler, SeparateSplit, SpreadSplit
from .templates import Template
logger = get_logger()
# Thin subclasses whose only purpose is to give recipe steps meaningful names
class CreateDemosPool(SeparateSplit):
    """SeparateSplit under a descriptive name; adds no behavior of its own.

    The recipe uses it for the step that splits the demos pool off the split
    the demos are taken from, so that step is recognizable by its class name.
    """
class AddDemosField(SpreadSplit):
    """SpreadSplit under a descriptive name; adds no behavior of its own.

    The recipe uses it for the step that fills the demos field from the demos
    pool stream, so that step is recognizable by its class name.
    """
class BaseRecipe(Recipe, SourceSequentialOperator):
    """Recipe that turns a TaskCard into an ordered list of stream operators.

    ``prepare`` assembles ``steps`` in this order: the card loader (followed by
    a loader-limit refiner when ``loader_limit`` is set), the card's preprocess
    steps, the task, the augmentor (when it augments task input), the demos
    pool creation, the per-split refiners, the template, the demos sampling
    step, the instruction, the format, the augmentor (when it augments model
    input) and finally ``ToUnitxtGroup``.
    """

    # Supplies the loader, preprocess steps, task and the fallback sampler.
    card: TaskCard
    # Appended as a step; its postprocessors are handed to ToUnitxtGroup.
    # Must be set by the time prepare() runs.
    template: Template = None
    instruction: Instruction = Field(default_factory=EmptyInstruction)
    format: Format = Field(default_factory=SystemFormat)

    # When truthy, set on the card loader and enforced again by a StreamRefiner.
    loader_limit: int = None

    # Per-split caps; a split's refiner is only added when its cap is not None.
    max_train_instances: int = None
    max_validation_instances: int = None
    max_test_instances: int = None

    train_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
    validation_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
    test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)

    # Demonstrations: a pool of `demos_pool_size` instances is split off
    # `demos_taken_from` into `demos_pool_name`; the sampler is sized to
    # `num_demos` and fills `demos_field` from that pool.
    demos_pool_size: int = None
    num_demos: int = 0
    demos_pool_name: str = "demos_pool"
    demos_taken_from: str = "train"
    demos_field: str = "demos"
    # Falls back to card.sampler when left as None and num_demos > 0.
    sampler: Sampler = None

    augmentor: Augmentor = OptionalField(default_factory=NullAugmentor)

    # Rebuilt from scratch by prepare().
    steps: List[StreamingOperator] = InternalField(default_factory=list)

    def before_process_multi_stream(self):
        """Re-initialize the sampler's random generator before processing."""
        super().before_process_multi_stream()
        if self.sampler:  # e.g. when num_demos is 0, the sampler may not be initialized
            self.sampler.init_new_random_generator()

    def verify(self):
        """Validate the demo settings and the instance caps against ``loader_limit``.

        Raises:
            ValueError: If ``num_demos`` is positive while ``demos_pool_size`` is
                missing, below 1, below ``num_demos`` or above ``loader_limit``;
                or if any ``max_*_instances`` value exceeds ``loader_limit``.
        """
        super().verify()
        if self.num_demos > 0:
            if self.demos_pool_size is None or self.demos_pool_size < 1:
                raise ValueError(
                    "When using demonstrations both num_demos and demos_pool_size should be assigned with postive integers."
                )
            if self.demos_pool_size < self.num_demos:
                raise ValueError(
                    f"num_demos (got: {self.num_demos}) should not exceed demos_pool_size (got: {self.demos_pool_size})"
                )
            if self.loader_limit and self.demos_pool_size > self.loader_limit:
                raise ValueError(
                    f"demos_pool_size should not exceed loader_limit ({self.loader_limit}), Got demos_pool_size={self.demos_pool_size}"
                )

        if self.loader_limit:
            # Checked in this order (test, validation, train) so the first
            # reported violation is the same as with the unrolled checks.
            for name in (
                "max_test_instances",
                "max_validation_instances",
                "max_train_instances",
            ):
                value = getattr(self, name)
                if value and value > self.loader_limit:
                    raise ValueError(
                        f"{name} should not exceed loader_limit ({self.loader_limit}), Got {name}={value}"
                    )

    def prepare_refiners(self):
        """Configure the per-split refiners and append those that have a cap.

        Every refiner gets its ``max_instances`` and ``apply_to_streams`` set,
        but it is only added to ``steps`` when the matching
        ``max_<split>_instances`` is not None.
        """
        for split in ("train", "validation", "test"):
            refiner = getattr(self, f"{split}_refiner")
            limit = getattr(self, f"max_{split}_instances")
            refiner.max_instances = limit
            refiner.apply_to_streams = [split]
            if limit is not None:
                self.steps.append(refiner)

    def prepare(self):
        """Build ``self.steps`` from the card and the recipe settings.

        Raises:
            ValueError: If no template is set, or if ``num_demos`` is positive
                and neither the recipe nor the card provides a sampler.
        """
        # Fail fast with a clear message: a None template would otherwise be
        # appended as a step and only blow up later with an AttributeError
        # when its postprocessors are requested.
        if self.template is None:
            raise ValueError(
                "Unexpected None value for template. "
                "A template must be set before the recipe is prepared."
            )

        self.steps = [
            self.card.loader,
        ]

        if self.loader_limit:
            self.card.loader.loader_limit = self.loader_limit
            logger.info(f"Loader line limit was set to {self.loader_limit}")
            self.steps.append(StreamRefiner(max_instances=self.loader_limit))

        if self.card.preprocess_steps is not None:
            self.steps.extend(self.card.preprocess_steps)

        self.steps.append(self.card.task)

        if self.augmentor.augment_task_input:
            self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
            self.steps.append(self.augmentor)

        if self.demos_pool_size is not None:
            # Only the pool size is given; the source split name is listed
            # second in to_split_names without an explicit size.
            self.steps.append(
                CreateDemosPool(
                    from_split=self.demos_taken_from,
                    to_split_names=[self.demos_pool_name, self.demos_taken_from],
                    to_split_sizes=[int(self.demos_pool_size)],
                )
            )

        if self.num_demos > 0:
            if self.sampler is None:
                if self.card.sampler is None:
                    raise ValueError(
                        "Unexpected None value for card.sampler. "
                        "To use num_demos > 0, please set a sampler on the TaskCard."
                    )
                self.sampler = self.card.sampler

            self.sampler.set_size(self.num_demos)

        self.prepare_refiners()

        self.steps.append(self.template)
        if self.num_demos > 0:
            self.steps.append(
                AddDemosField(
                    source_stream=self.demos_pool_name,
                    target_field=self.demos_field,
                    sampler=self.sampler,
                )
            )
        self.steps.append(self.instruction)
        self.steps.append(self.format)
        if self.augmentor.augment_model_input:
            # Same augmentor instance as above, this time placed after the format step.
            self.steps.append(self.augmentor)

        postprocessors = self.template.get_postprocessors()

        self.steps.append(
            ToUnitxtGroup(
                group="unitxt",
                metrics=self.card.task.metrics,
                postprocessors=postprocessors,
            )
        )
class StandardRecipeWithIndexes(BaseRecipe):
    """BaseRecipe whose template and instruction may be picked from the card.

    ``template_card_index`` is looked up in ``card.templates`` and
    ``instruction_card_index`` (converted with ``int``) in
    ``card.instructions``; the selected objects are stored on the recipe
    before the base ``prepare`` builds the steps.
    """

    instruction_card_index: int = None
    template_card_index: int = None

    def prepare(self):
        """Resolve index-based template/instruction, then run the base prepare.

        Raises:
            AssertionError: If both ``template`` and ``template_card_index`` are
                given, if neither is given, or if ``instruction_card_index`` is
                combined with an explicitly provided instruction.
            ValueError: If ``template_card_index`` cannot be found in
                ``card.templates``.
        """
        # Raised explicitly (rather than via `assert`) so the validation is
        # not stripped when Python runs with -O; the exception type is kept.
        if self.template_card_index is not None and self.template is not None:
            raise AssertionError(
                f"Specify either template ({self.template}) or template_card_index ({self.template_card_index}) but not both"
            )
        if self.template_card_index is None and self.template is None:
            raise AssertionError(
                "Specify either template or template_card_index in card"
            )

        if self.template_card_index is not None:
            try:
                self.template = self.card.templates[self.template_card_index]
            except Exception as e:
                # The templates container is a project type, so the exact
                # lookup error is not known here; any failure is reported
                # together with the available keys/indexes.
                if isinstance(self.card.templates, dict):
                    options = self.card.templates.keys()
                else:
                    options = list(range(0, len(self.card.templates)))
                raise ValueError(
                    f"card_template_index '{self.template_card_index}' is not in card. Available options: {options}"
                ) from e

        if self.instruction_card_index is not None:
            # `instruction` defaults to an EmptyInstruction instance, never
            # None, so a bare `is None` test would reject every use of
            # instruction_card_index. The placeholder counts as "not given".
            instruction_given = not (
                self.instruction is None
                or isinstance(self.instruction, EmptyInstruction)
            )
            if instruction_given:
                raise AssertionError(
                    "Specify either instruction or instruction_card_index"
                )
            self.instruction = self.card.instructions[int(self.instruction_card_index)]

        super().prepare()
class StandardRecipe(StandardRecipeWithIndexes):
    """Standard recipe for data processing and preparation.

    Adds nothing to ``StandardRecipeWithIndexes``; it exposes the same
    parameters and builds the same sequence of steps (loader, preprocessing,
    task, refiners, template, demos, instruction, format) under the public
    name.

    Attributes:
        card (TaskCard): TaskCard object associated with the recipe.
        template (Template, optional): Template object to be used for the recipe.
        instruction (Instruction, optional): Instruction object to be used for
            the recipe.
        loader_limit (int, optional): Maximum number of instances per stream to
            be returned from the loader (used to reduce loading time in large
            datasets).
        format (Format, optional): Format object to be used for the recipe.
            Defaults to a SystemFormat.
        train_refiner (StreamRefiner, optional): Refiner applied to the train
            stream.
        max_train_instances (int, optional): Maximum number of train instances
            kept by the train refiner.
        validation_refiner (StreamRefiner, optional): Refiner applied to the
            validation stream.
        max_validation_instances (int, optional): Maximum number of validation
            instances kept by the validation refiner.
        test_refiner (StreamRefiner, optional): Refiner applied to the test
            stream.
        max_test_instances (int, optional): Maximum number of test instances
            kept by the test refiner.
        demos_pool_size (int, optional): Size of the demos pool.
        num_demos (int, optional): Number of demos to be used. Default is 0.
        demos_pool_name (str, optional): Name of the demos pool. Default is
            "demos_pool".
        demos_taken_from (str, optional): Split the demos are taken from.
            Default is "train".
        demos_field (str, optional): Field name for demos. Default is "demos".
        sampler (Sampler, optional): Sampler object to be used in the recipe.
        steps (List[StreamingOperator], optional): List of StreamingOperator
            objects making up the recipe; rebuilt by ``prepare``.
        augmentor (Augmentor): Augmentor to be used to pseudo randomly augment
            the source text.
        instruction_card_index (int, optional): Index of the card instruction
            to be used for preparing the recipe.
        template_card_index (int, optional): Index of the card template to be
            used for preparing the recipe.

    Methods:
        prepare(): Inherited; resolves the template/instruction and arranges
            all the steps in sequential order.

    Raises:
        AssertionError: If template and template_card_index are both specified
            or both missing, or if instruction_card_index is combined with an
            instruction.
        ValueError: If template_card_index is not found in the card.
    """
|