metric / fusion.py
Elron's picture
Upload folder using huggingface_hub
d08fbc6 verified
raw
history blame
6.33 kB
from abc import abstractmethod
from typing import Dict, Generator, List, Optional, Union
from .dataclass import NonPositionalField
from .operator import SourceOperator
from .random_utils import new_random_generator
from .stream import DynamicStream, MultiStream
from .type_utils import isoftype
class BaseFusion(SourceOperator):
"""BaseFusion operator that combines multiple multistreams into one.
Args:
subsets: a dict of named SourceOperator objects (each to yield a MultiStream) or a list thereof,
each is specified along with its input, so can generate a MultiStream
include_splits: List of splits to include from each input MultiStream.
If None, all splits are included.
"""
subsets: Union[List[SourceOperator], Dict[str, SourceOperator]]
include_splits: Optional[List[str]] = NonPositionalField(default=None)
@abstractmethod
def fusion_generator(self, split) -> Generator:
pass
def prepare(self):
assert isoftype(self.subsets, Dict[str, SourceOperator]) or isoftype(
self.subsets, List[SourceOperator]
)
self.named_subsets = (
{i: self.subsets[i]() for i in range(len(self.subsets))}
if isinstance(self.subsets, list)
else {name: origin() for name, origin in self.subsets.items()}
)
def splits(self) -> List[str]:
splits = []
for _, origin in self.named_subsets.items():
for s in origin.keys():
if s not in splits:
if self.include_splits is None or s in self.include_splits:
splits.append(s)
return splits
def process(
self,
) -> MultiStream:
result = {}
for split in self.splits():
result[split] = DynamicStream(
self.fusion_generator, gen_kwargs={"split": split}
)
return MultiStream(result)
class FixedFusion(BaseFusion):
"""FixedFusion operator that combines multiple multistreams into one, limiting the number of instances taken from each split of each input multistream.
Args:
subsets: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
splits: List of splits (stream_names) to include, over all input multistreams. If None, all splits are included.
max_instances_per_subset: Number of instances to take from each input split of each input multistream.
If None, all instances of each split (that is specified in include_splits) are included in the result.
"""
max_instances_per_subset: Optional[int] = None
def prepare(self):
super().prepare()
# flake8: noqa: C901
def fusion_generator(self, split) -> Generator:
for origin_name, origin in self.named_subsets.items():
if split not in origin:
continue
emitted_from_this_split = 0
try:
for instance in origin[split]:
if (
self.max_instances_per_subset is not None
and emitted_from_this_split >= self.max_instances_per_subset
):
break
if isinstance(origin_name, str):
if "subset" not in instance:
instance["subset"] = []
instance["subset"].insert(0, origin_name)
emitted_from_this_split += 1
yield instance
except Exception as e:
raise RuntimeError(f"Exception in subset: {origin_name}") from e
class WeightedFusion(BaseFusion):
"""Fusion operator that combines multiple MultiStream-s.
Args:
subsets: Dict of named MultiStream objects, or a list thereof
weights: Dict of named weights for each origin, or a list thereof
max_total_examples: Total number of instances to return per returned split.
If None, all instances are returned
"""
subsets: Union[Dict[str, SourceOperator], List[SourceOperator]] = None
weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
max_total_samples: int = None
def verify(self):
super().verify()
assert self.subsets is not None, "subsets must be specified"
assert self.weights is not None, "weights must be specified"
assert len(self.subsets) == len(
self.weights
), "subsets and weights must have the same length"
assert isoftype(self.subsets, Dict[str, SourceOperator]) or isoftype(
self.subsets, List[SourceOperator]
)
assert isoftype(self.weights, Dict[str, Union[int, float]]) or isoftype(
self.weights, List[Union[int, float]]
)
assert isinstance(self.subsets, dict) == isinstance(self.weights, dict)
def prepare(self):
super().prepare()
self.named_weights = (
{i: float(self.weights[i]) for i in range(len(self.weights))}
if isinstance(self.weights, list)
else {k: float(v) for (k, v) in self.weights.items()}
)
def fusion_generator(self, split) -> Generator:
iterators = {
named_origin: iter(origin[split])
for named_origin, origin in self.named_subsets.items()
}
total_examples = 0
random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
while (
self.max_total_samples is None or total_examples < self.max_total_samples
) and len(iterators) > 0:
population = list(iterators.keys())
origin_name = random_generator.choices(
population=population,
weights=[self.named_weights[name] for name in population],
)[0]
iterator = iterators[origin_name]
try:
instance = next(iterator)
if isinstance(origin_name, str):
if "subset" not in instance:
instance["subset"] = []
instance["subset"].insert(0, origin_name)
total_examples += 1
yield instance
except StopIteration:
iterators.pop(origin_name)