File size: 14,739 Bytes
99fde4e
e203384
f6ebc4f
 
2341544
 
f6ebc4f
04d2454
80500e3
e7c76e5
 
 
2341544
99fde4e
e7c76e5
 
100c2eb
f6ebc4f
fe70438
04d2454
 
 
 
e7c76e5
 
99fde4e
 
 
 
 
 
 
 
e7c76e5
68d64cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7c76e5
 
 
 
 
e755967
e7c76e5
 
99fde4e
04d2454
 
99fde4e
 
 
 
 
 
 
571af6d
99fde4e
 
 
 
 
 
 
 
 
04d2454
 
 
571af6d
04d2454
99fde4e
04d2454
 
 
571af6d
 
99fde4e
 
 
e755967
99fde4e
 
e7c76e5
 
 
 
 
 
e755967
e7c76e5
 
9d5b4c0
 
e203384
 
9d5b4c0
e203384
04d2454
9d5b4c0
 
 
 
04d2454
e203384
e7c76e5
571af6d
 
 
f6ebc4f
 
100c2eb
 
f6ebc4f
 
 
100c2eb
 
 
571af6d
e7c76e5
 
f6ebc4f
 
 
 
9d5b4c0
f6ebc4f
 
 
 
9d5b4c0
 
f6ebc4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d5b4c0
f6ebc4f
 
 
 
 
 
9d5b4c0
f6ebc4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d2454
9d5b4c0
 
 
 
04d2454
f6ebc4f
 
 
e7c76e5
f6ebc4f
 
 
 
 
9d5b4c0
f6ebc4f
 
 
 
 
 
 
 
9d5b4c0
 
e7c76e5
 
78a0600
80500e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78a0600
80500e3
571af6d
78a0600
 
 
80500e3
78a0600
707625d
f6ebc4f
 
 
04d2454
80500e3
04d2454
 
b462f85
 
 
 
 
 
04d2454
f6ebc4f
 
 
80500e3
 
 
f6ebc4f
707625d
04d2454
707625d
04d2454
 
707625d
78a0600
707625d
04d2454
707625d
 
571af6d
 
78a0600
 
707625d
78a0600
 
04d2454
f6ebc4f
9d5b4c0
f6ebc4f
 
04d2454
80500e3
 
 
9d5b4c0
f6ebc4f
78a0600
 
9d5b4c0
80500e3
9d5b4c0
80500e3
78a0600
 
 
9d5b4c0
78a0600
9d5b4c0
80500e3
78a0600
 
 
 
 
 
 
f6ebc4f
78a0600
 
f6ebc4f
78a0600
 
 
9d5b4c0
 
 
 
e7c76e5
 
 
04d2454
e7c76e5
9d5b4c0
 
 
e7c76e5
04d2454
 
 
9d5b4c0
04d2454
 
fe70438
04d2454
 
571af6d
 
 
9d5b4c0
100c2eb
 
 
9d5b4c0
 
 
 
04d2454
100c2eb
 
9d5b4c0
04d2454
9d5b4c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
import itertools
from abc import abstractmethod
from difflib import get_close_matches
from typing import Dict, List, Optional

from .artifact import Artifact
from .dict_utils import dict_get
from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
from .random_utils import new_random_generator
from .split_utils import (
    parse_random_mix_string,
    parse_slices_string,
    random_mix_streams,
    rename_split,
    slice_streams,
)
from .stream import EmptyStreamError, FaultyStreamError, MultiStream
from .type_utils import isoftype
from .utils import recursive_copy


class Splitter(MultiStreamOperator):
    """Common base class for the operators in this module that rebuild the splits (streams) of a MultiStream."""


class RenameSplits(Splitter):
    """Rename the streams of a multistream according to ``mapper``.

    ``mapper`` is handed to ``split_utils.rename_split``; presumably it maps an
    existing stream name to its new name (verify against ``rename_split``).
    """

    mapper: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # Build the MultiStream directly from the renamed generators.
        return MultiStream(rename_split(multi_stream, self.mapper))


class SplitRandomMix(Splitter):
    """Build new streams (splits) by randomly mixing portions of the input streams.

    ``mix`` maps each output stream name to a specification string of the form
    'name-of-source-stream[portion-of-source-stream]', optionally joined with '+'.
    Each input instance, of any input stream, is selected exactly once for
    inclusion in any of the output streams.

    Examples:
        Given a multistream with two streams, 'train' and 'test',
        SplitRandomMix(mix={"train": "train[99%]", "validation": "train[1%]", "test": "test"})
        produces three streams, 'train', 'validation' and 'test':
        output 'train' holds a random 99% of input 'train', output 'validation'
        holds the remaining 1% of input 'train', and output 'test' holds all of
        input 'test'.

        With the same input,
        SplitRandomMix(mix={"train": "train[50%]+test[0.1]", "validation": "train[50%]+test[0.2]", "test": "test[0.7]"})
        produces:
        output 'train' = a random 50% of input 'train' + a random 0.1 (10%) of input 'test';
        output 'validation' = the other 50% of input 'train' + a random 0.2 (20%)
        of the original input 'test', drawn from the instances not already taken for 'train';
        output 'test' = the remaining instances of input 'test'.
    """

    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        parsed_mix = {}
        for split_name, spec in self.mix.items():
            parsed_mix[split_name] = parse_random_mix_string(spec)
        mixed_generators = random_mix_streams(multi_stream, parsed_mix)
        return MultiStream.from_generators(mixed_generators)


class SeparateSplit(Splitter):
    """Separates a split (e.g. train) into several splits (e.g. train1, train2).

    ``to_split_sizes`` gives the size of every target split, or of every target
    split except the last. If no size is given for the last split, it receives
    all examples of ``from_split`` not allocated to an earlier split.

    Streams other than ``from_split`` are passed through unchanged. When
    ``remove_targets_from_source_split`` is True (the default), ``from_split``
    itself is dropped from the output unless it is also named in
    ``to_split_names``. When it is False, ``from_split`` is kept whole and a
    target named the same as ``from_split`` is ignored.
    """

    from_split: str
    to_split_names: List[str]
    to_split_sizes: List[int]
    remove_targets_from_source_split: bool = True

    def verify(self):
        num_names = len(self.to_split_names)
        num_sizes = len(self.to_split_sizes)
        # Either every split has a size, or only the last one is left open-ended.
        assert num_names in (
            num_sizes,
            num_sizes + 1,
        ), f"Examples num should be specified to all or all but the last splits, instead given {num_names} split names and {num_sizes} split sizes. \n split names:{self.to_split_names} split sizes {self.to_split_sizes}"
        return super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        slicing = {}
        for stream_name in multi_stream.keys():
            if (
                self.remove_targets_from_source_split
                and stream_name == self.from_split
            ):
                continue
            # (None, None) presumably selects the whole stream — see split_utils.slice_streams.
            slicing[stream_name] = {stream_name: [(None, None)]}

        offset = 0
        # zip_longest pads a missing last size with None, i.e. "up to the end".
        for target_name, target_size in itertools.zip_longest(
            self.to_split_names, self.to_split_sizes
        ):
            keep_target = (
                self.remove_targets_from_source_split
                or target_name != self.from_split
            )
            if keep_target:
                slicing[target_name] = {self.from_split: [(offset, target_size)]}
            if target_size:
                offset += target_size

        return MultiStream.from_generators(slice_streams(multi_stream, slicing))


class SliceSplit(Splitter):
    """Build output streams from slices of the input streams.

    ``slices`` maps each output stream name to a slice specification string,
    parsed with ``split_utils.parse_slices_string`` (see that helper for the
    accepted format).
    """

    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        parsed_slices = {}
        for split_name, spec in self.slices.items():
            parsed_slices[split_name] = parse_slices_string(spec)
        sliced_generators = slice_streams(multi_stream, parsed_slices)
        return MultiStream.from_generators(sliced_generators)


def get_random_generator_based_on_instance(instance):
    """Return a random generator whose sub-seed is derived from ``instance["input_fields"]``.

    A shallow copy of the input fields is passed as ``sub_seed`` to
    ``new_random_generator``. Presumably equal input fields therefore yield
    reproducible draws; confirm against ``random_utils``.
    Raises KeyError if ``instance`` has no ``"input_fields"`` entry.
    """
    fields_snapshot = {**instance["input_fields"]}
    return new_random_generator(sub_seed=fields_snapshot)


class Sampler(Artifact):
    """Base class for strategies that choose demonstration instances from a pool.

    Subclasses implement :meth:`sample`. :meth:`filter_source_by_instance` can be
    used beforehand to drop pool items that share the target instance's inputs.
    """

    @abstractmethod
    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Dict[str, object],
    ) -> List[Dict[str, object]]:
        """Return ``sample_size`` instances chosen from ``instances_pool`` for ``instance``."""
        pass

    def filter_source_by_instance(
        self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
    ) -> List[Dict[str, object]]:
        """Return the pool items whose ``input_fields`` differ from ``instance``'s.

        This keeps an instance from being offered as its own demonstration.

        Raises:
            ValueError: if ``instance`` has no ``"input_fields"`` entry.
            KeyError: if an item of ``instances_pool`` has no ``"input_fields"``.
        """
        if "input_fields" not in instance:
            raise ValueError(f"'input_fields' field is missing from '{instance}'.")
        # The former try/except only re-raised the same exception; errors now
        # propagate unchanged without the extra traceback frame.
        target_fields = instance["input_fields"]
        return [
            item for item in instances_pool if item["input_fields"] != target_fields
        ]


class RandomSampler(Sampler):
    """Select ``sample_size`` instances from the pool at random.

    The generator is seeded from ``instance`` via
    ``get_random_generator_based_on_instance``, and its ``sample`` method does
    the selection.
    """

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        # Materialize first so any iterable pool can be sampled from.
        candidates = list(instances_pool)
        rng = get_random_generator_based_on_instance(instance)
        return rng.sample(candidates, sample_size)


class FixedIndicesSampler(Sampler):
    """Select a fixed set of pool instances given by ``indices``.

    Only the first ``sample_size`` entries of ``indices`` are used, in order.
    An index greater than or equal to the pool size raises ValueError.
    """

    indices: List[int]

    def verify(self):
        indices_ok = isoftype(self.indices, List[int])
        assert (
            indices_ok
        ), f"'indices' of {self.__class__.__name__} must be List[int]. Value {self.indices} is of type {type(self.indices)}"
        super().verify()

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        pool_size = len(instances_pool)

        selected = []
        # Check and take indices one at a time, in the order they are listed.
        for idx in self.indices[:sample_size]:
            if idx >= pool_size:
                raise ValueError(
                    f"FixedIndicesSampler 'indices' field contains index ({idx}) which is out of bounds of the instance pool ( of size {pool_size})"
                )
            selected.append(instances_pool[idx])
        return selected


class CloseTextSampler(Sampler):
    """Select the pool instances whose text is the closest match to the given instance.

    Texts are read from ``input_fields/<field>`` and ranked with
    ``difflib.get_close_matches``.
    """

    field: str

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Dict[str, object],
    ) -> List[Dict[str, object]]:
        field_path = f"input_fields/{self.field}"
        query_text = dict_get(instance, field_path)

        candidates = list(instances_pool)

        # Rank the pool texts by similarity to the query and keep the best
        # 'sample_size' of them. cutoff=0 keeps every candidate eligible.
        candidate_texts = [dict_get(candidate, field_path) for candidate in candidates]
        best_texts = get_close_matches(
            query_text, candidate_texts, n=sample_size, cutoff=0
        )

        # Several pool instances may share one of the best texts. Keep them all,
        # then let the seeded generator pick 'sample_size' of them in random order.
        shortlisted = [
            candidate
            for candidate in candidates
            if dict_get(candidate, field_path) in best_texts
        ]
        rng = get_random_generator_based_on_instance(instance)
        return rng.sample(shortlisted, sample_size)


class DiverseLabelsSampler(Sampler):
    """Selects a balanced sample of instances based on an output field.

    (Used for selecting demonstrations for in-context learning.)

    Each exemplar is represented by the subset of its input choices that also
    appear in its reference labels, kept in the order of the choices. The sample
    is spread over these representations so that each one appears as evenly as
    possible.

    Example:
        If an exemplar's choices are ['dog', 'cat'], its representation is one of:
        []
        ['dog']
        ['cat']
        ['dog', 'cat']

        Labels that are not among the choices are ignored. For example, if the
        choices are ['dog', 'cat'] and the labels are ['dog', 'cat', 'cow'],
        'cow' is ignored and the exemplar is represented as ['dog', 'cat'].

    Args:
        choices: key in an exemplar's ``input_fields`` holding its candidate
            values (a list, or a single string).
        labels: key in an exemplar's ``reference_fields`` holding its labels
            (must be a list).
        include_empty_label: whether exemplars whose representation is empty
            (``[]``) may be sampled.
    """

    choices: str = "choices"
    labels: str = "labels"
    include_empty_label: bool = True

    def prepare(self):
        super().prepare()
        # Maps representation string -> exemplars. Filled lazily by sample().
        self.labels_cache = None

    def exemplar_repr(self, exemplar):
        """Return the string form of the choices of ``exemplar`` that appear in its labels.

        Raises:
            ValueError: if required fields are missing or have unexpected types.
        """
        if "input_fields" not in exemplar:
            raise ValueError(f"'input_fields' field is missing from '{exemplar}'.")
        inputs = exemplar["input_fields"]
        if self.choices not in inputs:
            raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.")
        choices = inputs[self.choices]
        if not isinstance(choices, list):
            if isinstance(choices, str):
                choices = [choices]
            else:
                raise ValueError(
                    f"Unexpected input choices value '{choices}'. Expected a list or a string."
                )

        if "reference_fields" not in exemplar:
            raise ValueError(f"'reference_fields' field is missing from '{exemplar}'.")
        outputs = exemplar["reference_fields"]
        if self.labels not in outputs:
            raise ValueError(f"'{self.labels}' field is missing from '{outputs}'.")

        exemplar_outputs = exemplar["reference_fields"][self.labels]
        if not isinstance(exemplar_outputs, list):
            raise ValueError(
                f"Unexpected exemplar_outputs value '{exemplar_outputs}'. Expected a list."
            )

        return str([choice for choice in choices if choice in exemplar_outputs])

    def divide_by_repr(self, exemplars_pool):
        """Group exemplars by their representation.

        Exemplars whose representation is empty are skipped when
        ``include_empty_label`` is False.
        """
        labels = {}
        for exemplar in exemplars_pool:
            label_repr = self.exemplar_repr(exemplar)
            if label_repr == "[]" and not self.include_empty_label:
                continue
            if label_repr not in labels:
                labels[label_repr] = []
            labels[label_repr].append(exemplar)
        return labels

    def _allocate(self, ordered_labels, sample_size):
        """Round-robin over ``ordered_labels``, giving one slot per label per pass.

        Raises:
            ValueError: if the cached exemplars run out before ``sample_size``
                slots are allocated. The former loop spun forever in that case.
        """
        from collections import Counter

        total_allocated = 0
        allocations = Counter()
        while total_allocated < sample_size:
            allocated_this_pass = 0
            for label in ordered_labels:
                if total_allocated >= sample_size:
                    break
                if len(self.labels_cache[label]) - allocations[label] > 0:
                    allocations[label] += 1
                    total_allocated += 1
                    allocated_this_pass += 1
            if allocated_this_pass == 0:
                # Every label is exhausted, e.g. empty representations were
                # excluded, so no further pass can make progress.
                raise ValueError(
                    f"Request sample size {sample_size} is greater than number of instances available for sampling {total_allocated}"
                )
        return allocations

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        """Return ``sample_size`` exemplars balanced across representations, in random order.

        Raises:
            ValueError: if ``sample_size`` exceeds the pool size, or exceeds the
                exemplars available in the cache.
        """
        # NOTE(review): the cache is built from the first pool passed here and
        # reused on later calls. Callers that pass a per-instance filtered pool
        # will still draw from that first pool. Confirm this is intended.
        if self.labels_cache is None:
            self.labels_cache = self.divide_by_repr(instances_pool)
        all_labels = list(self.labels_cache.keys())
        random_generator = get_random_generator_based_on_instance(instance)
        random_generator.shuffle(all_labels)

        if sample_size > len(instances_pool):
            raise ValueError(
                f"Request sample size {sample_size} is greater than number of instances {len(instances_pool)}"
            )

        allocations = self._allocate(all_labels, sample_size)

        result = []
        for label, allocation in allocations.items():
            sample = random_generator.sample(self.labels_cache[label], allocation)
            result.extend(sample)

        random_generator.shuffle(result)
        return result


class Sample(InstanceOperatorWithMultiStreamAccess):
    """Attach a sample of instances drawn from another stream to each processed instance.

    For every instance, ``get_sample_size`` decides how many instances to draw.
    The pool is a cached copy of stream ``from_stream``, filtered through
    ``sampler.filter_source_by_instance``. ``sampler`` picks from that pool, and
    the picks are stored in ``instance[to_field]``.
    """

    from_stream: str
    to_field: str
    sampler: Sampler

    def prepare(self):
        # Copy of from_stream, materialized on first use and reused for every instance.
        self.local_cache = None
        self.sampler.prepare()

    @abstractmethod
    def get_sample_size(self, instance) -> int:
        """Return how many instances to sample for ``instance``."""
        pass

    def process(
        self, instance: Dict[str, object], multi_stream: MultiStream
    ) -> Dict[str, object]:
        """Sample from ``from_stream`` into ``instance[to_field]`` and return ``instance``.

        Raises:
            ValueError: if the filtered pool is smaller than the requested sample size.
            EmptyStreamError: if reading ``from_stream`` raises FaultyStreamError.
        """
        sample_size = self.get_sample_size(instance)
        try:
            if self.local_cache is None:
                self.local_cache = recursive_copy(list(multi_stream[self.from_stream]))

            source_stream = self.sampler.filter_source_by_instance(
                self.local_cache, instance
            )
            if len(source_stream) < sample_size:
                # Report the local sample_size. Sampler defines no sample_size
                # attribute, so the old message raised AttributeError instead.
                raise ValueError(
                    f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}."
                )
            sampled_instances = self.sampler.sample(
                sample_size=sample_size, instances_pool=source_stream, instance=instance
            )
            instance[self.to_field] = sampled_instances
            return instance
        except FaultyStreamError as e:
            raise EmptyStreamError(
                f"Unable to fetch instances from '{self.from_stream}' to '{self.to_field}', due to {e.__class__.__name__}: {e}"
            ) from e


class ConstantSizeSample(Sample):
    """Sample the same number (``sample_size``) of instances for every instance."""

    sample_size: int

    def get_sample_size(self, instance) -> int:
        # The instance does not influence the size; it is fixed by configuration.
        fixed_size = self.sample_size
        return fixed_size


class RandomSizeSample(Sample):
    """Sample a number of instances chosen from ``sample_sizes`` for each instance.

    The choice uses a generator seeded from the instance via
    ``get_random_generator_based_on_instance``.
    """

    sample_sizes: List[int]

    def get_sample_size(self, instance) -> int:
        rng = get_random_generator_based_on_instance(instance)
        candidate_sizes = self.sample_sizes
        return rng.choice(candidate_sizes)