File size: 6,692 Bytes
cbca7b8 b462f85 5852323 300a7be 2d210f5 cc653d8 100c2eb b462f85 5c545d2 5852323 cbca7b8 b462f85 5c545d2 cbca7b8 d08fbc6 b462f85 cbca7b8 5c545d2 d08fbc6 300a7be 5c545d2 cbca7b8 5c545d2 82055e6 d08fbc6 b462f85 82055e6 b462f85 82055e6 5852323 5c545d2 cbca7b8 100c2eb b462f85 cbca7b8 5852323 5c545d2 cbca7b8 b462f85 5c545d2 cbca7b8 d08fbc6 b462f85 d08fbc6 b462f85 cbca7b8 5c545d2 d08fbc6 b462f85 5c545d2 b462f85 cbca7b8 d08fbc6 82055e6 ef1f482 b462f85 d08fbc6 82055e6 d08fbc6 5c545d2 5852323 cbca7b8 b462f85 5c545d2 cbca7b8 d08fbc6 b462f85 cbca7b8 5c545d2 d08fbc6 b462f85 d08fbc6 5c545d2 cbca7b8 d08fbc6 cbca7b8 d08fbc6 17a636b d08fbc6 b462f85 d08fbc6 b462f85 5c545d2 cbca7b8 82055e6 cbca7b8 cc653d8 17a636b d08fbc6 17a636b b462f85 cbca7b8 b462f85 d08fbc6 cbca7b8 b462f85 cbca7b8 b462f85 82055e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
from abc import abstractmethod
from typing import Dict, Generator, List, Optional, Union
from .dataclass import NonPositionalField
from .operator import SourceOperator
from .random_utils import new_random_generator
from .stream import DynamicStream, MultiStream
from .type_utils import isoftype
class BaseFusion(SourceOperator):
    """BaseFusion operator that combines multiple multistreams into one.

    Subclasses define how the instances of the subsets are merged by
    implementing ``fusion_generator``.

    Args:
        subsets: a dict of named SourceOperator objects (each to yield a MultiStream) or a list thereof,
            each is specified along with its input, so can generate a MultiStream
        include_splits: List of splits to produce in the fused MultiStream.
            If None, the splits "train", "test" and "validation" are produced.
    """

    subsets: Union[List[SourceOperator], Dict[str, SourceOperator]]
    include_splits: Optional[List[str]] = NonPositionalField(default=None)

    @abstractmethod
    def fusion_generator(self, split) -> Generator:
        """Yield the fused instances of the given split (implemented by subclasses)."""

    def prepare_subsets(self):
        """Populate ``self.named_subsets`` from ``self.subsets``.

        A dict input keeps its names as keys; a list input is keyed by the
        position of each operator in the list.
        """
        assert isoftype(self.subsets, Dict[str, SourceOperator]) or isoftype(
            self.subsets, List[SourceOperator]
        ), "subsets must be a Dict[str, SourceOperator] or a List[SourceOperator]"
        if isinstance(self.subsets, list):
            self.named_subsets = dict(enumerate(self.subsets))
        else:
            # Shallow copy, so named_subsets never aliases the user-supplied dict.
            self.named_subsets = dict(self.subsets)

    def splits(self) -> List[str]:
        """Return the names of the splits to produce.

        Also (re)builds ``self.named_subsets``; ``process`` calls this method
        before creating the streams whose generators read that mapping.
        """
        self.prepare_subsets()
        if self.include_splits is not None:
            return self.include_splits
        return ["train", "test", "validation"]

    def process(
        self,
    ) -> MultiStream:
        """Return a MultiStream holding one DynamicStream per split.

        Each stream is backed by ``fusion_generator`` invoked with that
        split's name.
        """
        return MultiStream(
            {
                split: DynamicStream(
                    self.fusion_generator, gen_kwargs={"split": split}
                )
                for split in self.splits()
            }
        )
class FixedFusion(BaseFusion):
    """FixedFusion operator that combines multiple multistreams into one, limiting the number of instances taken from each split of each input multistream.

    The subsets are consumed one after the other, in the order they are given.

    Args:
        subsets: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
        include_splits: List of splits (stream_names) to include, over all input multistreams.
            If None, the default splits of BaseFusion are used.
        max_instances_per_subset: Number of instances to take from each input split of each input multistream.
            If None, all instances of each split (that is specified in include_splits) are included in the result.
    """

    max_instances_per_subset: Optional[int] = None

    def prepare(self):
        """Defer to the parent class; no additional preparation is done here."""
        super().prepare()

    # flake8: noqa: C901
    def fusion_generator(self, split) -> Generator:
        """Yield the instances of ``split`` from each subset in turn.

        Subsets whose MultiStream lacks ``split`` are skipped. When the subset
        is named by a string, that name is inserted at the front of the
        instance's "subset" list (the list is created if missing).

        Raises:
            RuntimeError: if reading or tagging the instances of a subset
                fails; the original exception is chained as the cause.
        """
        limit = self.max_instances_per_subset
        for origin_name, origin in self.named_subsets.items():
            multi_stream = origin()
            if split not in multi_stream:
                continue
            if limit is not None and limit <= 0:
                # Nothing may be taken, so do not touch the stream at all.
                continue
            emitted_from_this_split = 0
            try:
                for instance in multi_stream[split]:
                    if isinstance(origin_name, str):
                        if "subset" not in instance:
                            instance["subset"] = []
                        instance["subset"].insert(0, origin_name)
                    yield instance
                    emitted_from_this_split += 1
                    # Check the limit right after emitting, so that no instance
                    # beyond the limit is ever pulled from the underlying stream.
                    if limit is not None and emitted_from_this_split >= limit:
                        break
            except Exception as e:
                raise RuntimeError(f"Exception in subset: {origin_name}") from e
class WeightedFusion(BaseFusion):
    """Fusion operator that combines multiple MultiStream-s by weighted random sampling.

    At every step one of the not-yet-exhausted subsets is drawn at random,
    with probability proportional to its weight, and its next instance is
    emitted.

    Args:
        subsets: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
        weights: Dict of named weights for each origin, or a list thereof.
            Must be of the same kind (dict or list) and length as subsets.
        max_total_samples: Total number of instances to return per returned split.
            If None, all instances are returned
    """

    subsets: Union[Dict[str, SourceOperator], List[SourceOperator]] = None
    weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
    max_total_samples: Optional[int] = None

    def verify(self):
        """Check that subsets and weights are both given and match each other."""
        super().verify()
        assert self.subsets is not None, "subsets must be specified"
        assert self.weights is not None, "weights must be specified"
        assert len(self.subsets) == len(
            self.weights
        ), "subsets and weights must have the same length"
        assert isoftype(self.subsets, Dict[str, SourceOperator]) or isoftype(
            self.subsets, List[SourceOperator]
        ), "subsets must be a Dict[str, SourceOperator] or a List[SourceOperator]"
        assert isoftype(self.weights, Dict[str, Union[int, float]]) or isoftype(
            self.weights, List[Union[int, float]]
        ), "weights must be a dict or a list of numbers"
        assert isinstance(self.subsets, dict) == isinstance(
            self.weights, dict
        ), "subsets and weights must both be dicts or both be lists"
        if isinstance(self.subsets, dict):
            # A subset without a weight would otherwise surface only later,
            # as a bare KeyError in the middle of fusion_generator.
            assert set(self.subsets) == set(
                self.weights
            ), "subsets and weights must have the same names"

    def prepare(self):
        """Build ``self.named_weights``, keyed like ``named_subsets`` (names, or list positions)."""
        super().prepare()
        if isinstance(self.weights, list):
            self.named_weights = {
                i: float(weight) for i, weight in enumerate(self.weights)
            }
        else:
            self.named_weights = {
                name: float(weight) for name, weight in self.weights.items()
            }

    def fusion_generator(self, split) -> Generator:
        """Yield instances of ``split``, drawing the source subset at random by weight.

        Subsets whose MultiStream lacks ``split`` take no part. An exhausted
        subset is dropped and sampling continues over the remaining ones.
        Generation ends when ``max_total_samples`` instances were emitted,
        when all subsets are exhausted, or when the remaining subsets have
        no positive total weight. When the subset is named by a string, that
        name is inserted at the front of the instance's "subset" list.

        Raises:
            RuntimeError: if fetching or tagging an instance of a subset
                fails; the original exception is chained as the cause.
        """
        iterators = {}
        for origin_name, origin in self.named_subsets.items():
            multi_stream = origin()
            if split not in multi_stream:
                continue
            iterators[origin_name] = iter(multi_stream[split])
        total_examples = 0
        random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
        while (
            self.max_total_samples is None or total_examples < self.max_total_samples
        ) and iterators:
            population = list(iterators)
            weights = [self.named_weights[name] for name in population]
            if sum(weights) <= 0:
                # Only zero-weight subsets remain. Assuming the standard
                # random.Random.choices contract, sampling would raise a
                # ValueError here; such subsets are never to be drawn, so stop.
                break
            origin_name = random_generator.choices(
                population=population,
                weights=weights,
            )[0]
            try:
                instance = next(iterators[origin_name])
                if isinstance(origin_name, str):
                    if "subset" not in instance:
                        instance["subset"] = []
                    instance["subset"].insert(0, origin_name)
            except StopIteration:
                # This subset is exhausted; keep sampling from the others.
                del iterators[origin_name]
                continue
            except Exception as e:
                raise RuntimeError(f"Exception in subset: {origin_name}") from e
            # The yield is kept outside the try block, so that exceptions the
            # consumer throws into the generator are not reported as subset errors.
            total_examples += 1
            yield instance
|