File size: 14,739 Bytes
99fde4e e203384 f6ebc4f 2341544 f6ebc4f 04d2454 80500e3 e7c76e5 2341544 99fde4e e7c76e5 100c2eb f6ebc4f fe70438 04d2454 e7c76e5 99fde4e e7c76e5 68d64cc e7c76e5 e755967 e7c76e5 99fde4e 04d2454 99fde4e 571af6d 99fde4e 04d2454 571af6d 04d2454 99fde4e 04d2454 571af6d 99fde4e e755967 99fde4e e7c76e5 e755967 e7c76e5 9d5b4c0 e203384 9d5b4c0 e203384 04d2454 9d5b4c0 04d2454 e203384 e7c76e5 571af6d f6ebc4f 100c2eb f6ebc4f 100c2eb 571af6d e7c76e5 f6ebc4f 9d5b4c0 f6ebc4f 9d5b4c0 f6ebc4f 9d5b4c0 f6ebc4f 9d5b4c0 f6ebc4f 04d2454 9d5b4c0 04d2454 f6ebc4f e7c76e5 f6ebc4f 9d5b4c0 f6ebc4f 9d5b4c0 e7c76e5 78a0600 80500e3 78a0600 80500e3 571af6d 78a0600 80500e3 78a0600 707625d f6ebc4f 04d2454 80500e3 04d2454 b462f85 04d2454 f6ebc4f 80500e3 f6ebc4f 707625d 04d2454 707625d 04d2454 707625d 78a0600 707625d 04d2454 707625d 571af6d 78a0600 707625d 78a0600 04d2454 f6ebc4f 9d5b4c0 f6ebc4f 04d2454 80500e3 9d5b4c0 f6ebc4f 78a0600 9d5b4c0 80500e3 9d5b4c0 80500e3 78a0600 9d5b4c0 78a0600 9d5b4c0 80500e3 78a0600 f6ebc4f 78a0600 f6ebc4f 78a0600 9d5b4c0 e7c76e5 04d2454 e7c76e5 9d5b4c0 e7c76e5 04d2454 9d5b4c0 04d2454 fe70438 04d2454 571af6d 9d5b4c0 100c2eb 9d5b4c0 04d2454 100c2eb 9d5b4c0 04d2454 9d5b4c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 |
import itertools
from abc import abstractmethod
from difflib import get_close_matches
from typing import Dict, List, Optional
from .artifact import Artifact
from .dict_utils import dict_get
from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
from .random_utils import new_random_generator
from .split_utils import (
parse_random_mix_string,
parse_slices_string,
random_mix_streams,
rename_split,
slice_streams,
)
from .stream import EmptyStreamError, FaultyStreamError, MultiStream
from .type_utils import isoftype
from .utils import recursive_copy
class Splitter(MultiStreamOperator):
    """Common base class for the split-manipulating operators defined in this module."""
class RenameSplits(Splitter):
    """Renames the streams of a multistream using ``mapper``.

    The renaming itself is delegated to ``rename_split``; its result is wrapped
    in a new ``MultiStream``.
    """

    # Passed as-is to rename_split; presumably maps existing split names to
    # their new names -- confirm against rename_split.
    mapper: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Return a MultiStream built from the renamed streams."""
        return MultiStream(rename_split(multi_stream, self.mapper))
class SplitRandomMix(Splitter):
    """Splits a multistream into new streams (splits) as specified by arg 'mix'.

    The keys of 'mix' are the names of the new streams; the values are of the form
    'name-of-source-stream[percentage-of-source-stream]'.
    Each input instance, of any input stream, is selected exactly once for inclusion
    in any of the output streams.

    Examples:
        When processing a multistream made of two streams whose names are 'train'
        and 'test', by
        SplitRandomMix(mix = { "train": "train[99%]", "validation": "train[1%]", "test": "test" })
        the output is a multistream whose three streams are named 'train',
        'validation', and 'test'.
        Output stream 'train' is made of randomly selected 99% of the instances of
        input stream 'train', output stream 'validation' is made of the remaining 1%
        of the instances of input 'train', and output stream 'test' is made of the
        whole of input stream 'test'.

        When processing the above input multistream by
        SplitRandomMix(mix = { "train": "train[50%]+test[0.1]", "validation": "train[50%]+test[0.2]", "test": "test[0.7]" })
        the output is a multistream whose three streams are named 'train',
        'validation', and 'test'.
        Output stream 'train' is made of randomly selected 50% of the instances of
        input stream 'train' + randomly selected 0.1 (i.e., 10%) of the instances of
        input stream 'test'.
        Output stream 'validation' is made of the remaining 50% of the instances of
        input 'train' + randomly selected 0.2 (i.e., 20%) of the original instances
        of input 'test' that were not selected for output 'train',
        and output stream 'test' is made of the remaining instances of input 'test'.
    """

    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Parse every mix expression and build the randomly mixed output streams."""
        parsed_mix = {}
        for target_name, mix_expression in self.mix.items():
            parsed_mix[target_name] = parse_random_mix_string(mix_expression)
        return MultiStream.from_generators(
            random_mix_streams(multi_stream, parsed_mix)
        )
class SeparateSplit(Splitter):
    """Separates a split (e.g. train) into several splits (e.g. train1, train2).

    'to_split_sizes' must give a size for every target split except, optionally,
    the last one. A target split without a size is mapped with an open end
    (``None``), i.e. it is meant to receive the examples not allocated to any
    other split.
    """

    from_split: str
    to_split_names: List[str]
    to_split_sizes: List[int]
    # When True, 'from_split' is not passed through unchanged to the output.
    remove_targets_from_source_split: bool = True

    def verify(self):
        """Check that sizes are given for all target splits, or all but the last."""
        names_count = len(self.to_split_names)
        sizes_count = len(self.to_split_sizes)
        assert (
            names_count - sizes_count in (0, 1)
        ), f"Examples num should be specified to all or all but the last splits, instead given {len(self.to_split_names)} split names and {len(self.to_split_sizes)} split sizes. \n split names:{self.to_split_names} split sizes {self.to_split_sizes}"
        return super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Build the slicing specification and apply it with ``slice_streams``."""
        keep_source = not self.remove_targets_from_source_split

        # Streams that pass through untouched: each one maps to its own full range.
        mapping = {}
        for key in multi_stream.keys():
            if keep_source or key != self.from_split:
                mapping[key] = {key: [(None, None)]}

        # Target splits: consecutive pieces of 'from_split'.
        # NOTE(review): the pair stored below is (offset, size). If slice_streams
        # interprets the pair as (start, stop), the second element would need to
        # be offset + size for every split after the first -- confirm.
        offset = 0
        for name, size in itertools.zip_longest(
            self.to_split_names, self.to_split_sizes
        ):
            if not keep_source or name != self.from_split:
                mapping[name] = {self.from_split: [(offset, size)]}
            if size:
                offset += size

        return MultiStream.from_generators(slice_streams(multi_stream, mapping))
class SliceSplit(Splitter):
    """Builds output streams from slice expressions given per output stream name.

    Every value of ``slices`` is parsed with ``parse_slices_string`` and the
    parsed specification is applied with ``slice_streams``.
    """

    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Parse the slice expressions and return the sliced multistream."""
        parsed_slices = {}
        for target_name, slice_expression in self.slices.items():
            parsed_slices[target_name] = parse_slices_string(slice_expression)
        return MultiStream.from_generators(
            slice_streams(multi_stream, parsed_slices)
        )
def get_random_generator_based_on_instance(instance):
    """Create a random generator whose sub-seed is the instance's 'input_fields'.

    A shallow copy of ``instance["input_fields"]`` is handed to
    ``new_random_generator`` as ``sub_seed``.
    """
    seed_fields = {**instance["input_fields"]}
    return new_random_generator(sub_seed=seed_fields)
class Sampler(Artifact):
    """Base class for strategies that select instances out of a pool."""

    @abstractmethod
    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Dict[str, object],
    ) -> List[Dict[str, object]]:
        """Select instances from ``instances_pool``; implemented by subclasses.

        Args:
            sample_size: requested number of instances.
            instances_pool: candidate instances to choose from.
            instance: the instance for which the selection is made.

        Returns:
            The selected instances.
        """
        pass

    def filter_source_by_instance(
        self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
    ) -> List[Dict[str, object]]:
        """Return the pool without items whose 'input_fields' equal the instance's.

        Args:
            instances_pool: candidate instances; each must have an 'input_fields' key.
            instance: the instance to exclude from the pool.

        Returns:
            A new list; ``instances_pool`` itself is not modified.

        Raises:
            ValueError: if ``instance`` has no 'input_fields' key.
        """
        if "input_fields" not in instance:
            raise ValueError(f"'input_fields' field is missing from '{instance}'.")
        # The former `try: ... except Exception as e: raise e` wrapper only
        # re-raised the same exception, so errors simply propagate now.
        own_inputs = instance["input_fields"]
        return [item for item in instances_pool if item["input_fields"] != own_inputs]
class RandomSampler(Sampler):
    """Selects a random sample of instances."""

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        """Draw ``sample_size`` instances from the pool.

        The random generator is derived from ``instance`` via
        ``get_random_generator_based_on_instance``.
        """
        pool = list(instances_pool)
        rng = get_random_generator_based_on_instance(instance)
        return rng.sample(pool, sample_size)
class FixedIndicesSampler(Sampler):
    """Selects a fixed set of samples, given as a list of indices into the pool."""

    indices: List[int]

    def verify(self):
        """Check that 'indices' is a list of ints, then run the parent's checks."""
        indices_ok = isoftype(self.indices, List[int])
        assert (
            indices_ok
        ), f"'indices' of {self.__class__.__name__} must be List[int]. Value {self.indices} is of type {type(self.indices)}"
        super().verify()

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        """Return the pool items at the first ``sample_size`` configured indices.

        ``instance`` is not used by this sampler.

        Raises:
            ValueError: if one of the used indices is >= the pool size.
        """
        num_instances = len(instances_pool)
        picked = []
        # Only the first 'sample_size' indices take part in the selection.
        for index in self.indices[:sample_size]:
            if index < num_instances:
                picked.append(instances_pool[index])
                continue
            raise ValueError(
                f"FixedIndicesSampler 'indices' field contains index ({index}) which is out of bounds of the instance pool ( of size {num_instances})"
            )
        return picked
class CloseTextSampler(Sampler):
    """Selects the instances which are the closest textual match to the given instance.

    Comparison is done on the value of one input field, named by ``field``.
    """

    field: str

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Dict[str, object],
    ) -> List[Dict[str, object]]:
        """Sample ``sample_size`` pool instances whose field text is closest to the instance's."""
        field = f"input_fields/{self.field}"
        value = dict_get(instance, field)

        instances_pool = list(instances_pool)

        # Find the 'sample_size' texts closest to the instance's text.
        # cutoff=0 means no candidate is rejected for being too dissimilar.
        options = [dict_get(candidate, field) for candidate in instances_pool]
        closest_matches = get_close_matches(value, options, n=sample_size, cutoff=0)

        # Keep every pool instance whose text is one of the closest matches
        # (several instances may share the same text), then randomly select
        # 'sample_size' of them; the order of the result is randomized as well.
        instances_pool = [
            candidate
            for candidate in instances_pool
            if dict_get(candidate, field) in closest_matches
        ]

        rng = get_random_generator_based_on_instance(instance)
        return rng.sample(instances_pool, sample_size)
class DiverseLabelsSampler(Sampler):
    """Selects a balanced sample of instances based on an output field.

    (used for selecting demonstrations in in-context learning)

    The output field must contain a list of values, e.g. ['dog'], ['cat'],
    ['dog','cat','cow']. The balancing is done such that each value, or
    combination of values, appears as equally as possible in the sample.

    The `choices` param is required and determines which values should be considered.

    Example:
        If choices is ['dog','cat'], then the following combinations will be considered.
        []
        ['cat']
        ['dog']
        ['dog','cat']

        If the instance contains a value not in the 'choices' param, it is ignored.
        For example, if choices is ['dog','cat'] and the instance field is
        ['dog','cat','cow'], then 'cow' is ignored and the instance is considered
        as ['dog','cat'].

    Args:
        choices - name of the input field that contains the list of values to balance on
        labels - name of the output field with labels that must be balanced
        include_empty_label - when False, instances whose label combination is
            empty are left out of the sampling pool
    """

    choices: str = "choices"
    labels: str = "labels"
    include_empty_label: bool = True

    def prepare(self):
        super().prepare()
        # Filled by the first call to sample() and reused by every later call.
        self.labels_cache = None

    def exemplar_repr(self, exemplar):
        """Return the string form of the exemplar's labels restricted to its choices.

        The result is ``str()`` of the list of choices (in choices order) that
        appear in the exemplar's labels, e.g. "['dog', 'cat']" or "[]".

        Raises:
            ValueError: if 'input_fields' / 'reference_fields' or the configured
                fields are missing, or if their values have an unexpected type.
        """
        if "input_fields" not in exemplar:
            raise ValueError(f"'input_fields' field is missing from '{exemplar}'.")
        inputs = exemplar["input_fields"]
        if self.choices not in inputs:
            raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.")
        choices = inputs[self.choices]
        if not isinstance(choices, list):
            if isinstance(choices, str):
                # A single string is treated as a one-element list of choices.
                choices = [choices]
            else:
                raise ValueError(
                    f"Unexpected input choices value '{choices}'. Expected a list or a string."
                )

        if "reference_fields" not in exemplar:
            raise ValueError(f"'reference_fields' field is missing from '{exemplar}'.")
        outputs = exemplar["reference_fields"]
        if self.labels not in outputs:
            raise ValueError(f"'{self.labels}' field is missing from '{outputs}'.")

        exemplar_outputs = outputs[self.labels]
        if not isinstance(exemplar_outputs, list):
            raise ValueError(
                f"Unexpected exemplar_outputs value '{exemplar_outputs}'. Expected a list."
            )

        return str([choice for choice in choices if choice in exemplar_outputs])

    def divide_by_repr(self, exemplars_pool):
        """Group the exemplars by their label representation.

        Returns:
            A dict mapping each representation (see ``exemplar_repr``) to the
            list of exemplars having it. Exemplars with the empty representation
            "[]" are skipped when ``include_empty_label`` is False.
        """
        labels = {}
        for exemplar in exemplars_pool:
            label_repr = self.exemplar_repr(exemplar)
            if label_repr == "[]" and not self.include_empty_label:
                continue
            labels.setdefault(label_repr, []).append(exemplar)
        return labels

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        """Return ``sample_size`` instances, spread as evenly as possible over label groups.

        Args:
            sample_size: number of instances to return.
            instances_pool: candidate instances. NOTE: the grouping by labels is
                computed from the pool given on the first call and cached, so
                later calls draw from that first grouping.
            instance: instance whose 'input_fields' seed the random generator.

        Raises:
            ValueError: if ``sample_size`` exceeds the size of ``instances_pool``
                or the number of instances available in the cached label groups.
        """
        from collections import Counter

        if self.labels_cache is None:
            self.labels_cache = self.divide_by_repr(instances_pool)
        all_labels = list(self.labels_cache.keys())
        random_generator = get_random_generator_based_on_instance(instance)
        random_generator.shuffle(all_labels)

        if sample_size > len(instances_pool):
            raise ValueError(
                f"Request sample size {sample_size} is greater than number of instances {len(instances_pool)}"
            )

        # The label groups may hold fewer instances than the pool (e.g. when
        # include_empty_label is False). Without this guard the round-robin
        # allocation below could never reach 'sample_size' and would loop forever.
        available = sum(len(group) for group in self.labels_cache.values())
        if sample_size > available:
            raise ValueError(
                f"Request sample size {sample_size} is greater than number of instances available for sampling by label {available}"
            )

        # Round-robin over the shuffled labels, taking one more instance from
        # each label that still has unallocated instances, until done.
        total_allocated = 0
        allocations = Counter()
        while total_allocated < sample_size:
            for label in all_labels:
                if total_allocated >= sample_size:
                    break
                if allocations[label] < len(self.labels_cache[label]):
                    allocations[label] += 1
                    total_allocated += 1

        result = []
        for label, allocation in allocations.items():
            result.extend(random_generator.sample(self.labels_cache[label], allocation))

        random_generator.shuffle(result)
        return result
class Sample(InstanceOperatorWithMultiStreamAccess):
    """Adds to each instance a field holding instances sampled from another stream.

    The instances of stream ``from_stream`` are materialized once (on the first
    processed instance), copied with ``recursive_copy`` and cached. For every
    instance, the cached pool is filtered by the sampler, a sample is drawn, and
    the result is stored under ``to_field``.

    Subclasses define how many instances to draw via ``get_sample_size``.
    """

    from_stream: str
    to_field: str
    sampler: Sampler

    def prepare(self):
        # Cache of the source stream's instances; filled lazily in process().
        self.local_cache = None
        self.sampler.prepare()

    @abstractmethod
    def get_sample_size(self, instance) -> int:
        """Return the number of instances to sample for ``instance``."""
        pass

    def process(
        self, instance: Dict[str, object], multi_stream: MultiStream
    ) -> Dict[str, object]:
        """Sample from ``from_stream`` and store the sample in ``instance[to_field]``.

        Returns:
            The same ``instance`` object, with ``to_field`` set.

        Raises:
            ValueError: if, after filtering, the pool is smaller than the
                requested sample size.
            EmptyStreamError: if a FaultyStreamError is raised while fetching or
                sampling the instances.
        """
        sample_size = self.get_sample_size(instance)
        try:
            if self.local_cache is None:
                self.local_cache = recursive_copy(list(multi_stream[self.from_stream]))

            source_stream = self.sampler.filter_source_by_instance(
                self.local_cache, instance
            )
            if len(source_stream) < sample_size:
                # Report the locally computed size: samplers take 'sample_size'
                # as an argument of sample() and have no such attribute, so
                # reading it from the sampler raised AttributeError instead.
                raise ValueError(
                    f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}."
                )
            sampled_instances = self.sampler.sample(
                sample_size=sample_size, instances_pool=source_stream, instance=instance
            )
            instance[self.to_field] = sampled_instances
            return instance
        except FaultyStreamError as e:
            raise EmptyStreamError(
                f"Unable to fetch instances from '{self.from_stream}' to '{self.to_field}', due to {e.__class__.__name__}: {e}"
            ) from e
class ConstantSizeSample(Sample):
    """Sample operator that draws the same number of instances for every instance."""

    sample_size: int

    def get_sample_size(self, instance) -> int:
        """Return the configured ``sample_size``; ``instance`` is not consulted."""
        fixed_size = self.sample_size
        return fixed_size
class RandomSizeSample(Sample):
    """Sample operator whose sample size is picked at random from ``sample_sizes``."""

    sample_sizes: List[int]

    def get_sample_size(self, instance) -> int:
        """Choose one of ``sample_sizes`` with a generator derived from ``instance``."""
        rng = get_random_generator_based_on_instance(instance)
        chosen_size = rng.choice(self.sample_sizes)
        return chosen_size
|