Last commit not found
import copy | |
from abc import abstractmethod | |
from typing import Generator, List, Optional | |
from .dataclass import NonPositionalField | |
from .operator import SourceOperator | |
from .random_utils import new_random_generator | |
from .stream import MultiStream, Stream | |
class BaseFusion(SourceOperator):
    """BaseFusion operator that combines multiple streams into one.

    Subclasses implement ``fusion_generator``, which yields the fused
    instances of a single split; ``process`` wraps it in one ``Stream`` per
    split.

    Args:
        origins: List of SourceOperator objects whose outputs are fused.
        include_splits: List of splits to include. If None, all splits are included.
    """

    origins: List[SourceOperator]
    include_splits: Optional[List[str]] = NonPositionalField(default=None)

    @abstractmethod
    def fusion_generator(self, split) -> Generator:
        """Yield the fused instances of ``split``; must be implemented by subclasses."""
        pass

    def splits(self) -> List[str]:
        """Return the split names offered by any origin, filtered by ``include_splits``.

        Splits are listed once each, in the order they are first seen while
        walking ``self.origins``.
        """
        splits = []
        for origin in self.origins:
            # NOTE(review): origin() is called here and again by the subclass
            # generators for every split; if it is expensive, consider caching.
            for split in origin().keys():
                if split in splits:
                    continue
                if self.include_splits is None or split in self.include_splits:
                    splits.append(split)
        return splits

    def process(
        self,
    ) -> MultiStream:
        """Build a MultiStream with one Stream per selected split, each backed by ``fusion_generator``."""
        result = {}
        for split in self.splits():
            result[split] = Stream(self.fusion_generator, gen_kwargs={"split": split})
        return MultiStream(result)
class FixedFusion(BaseFusion):
    """FixedFusion operator that concatenates multiple streams, taking at most a fixed number of instances from each origin.

    Origins are consumed one after another, in the order of ``origins``.
    An origin that does not provide the requested split is skipped.

    Args:
        origins: List of SourceOperator objects.
        max_instances_per_origin: Maximum number of instances taken from each origin. If None, all instances are taken.
        include_splits: List of splits to include. If None, all splits are included.
    """

    max_instances_per_origin: Optional[int] = None

    def fusion_generator(self, split) -> Generator:
        for origin in self.origins:
            multi_stream = origin()
            # The fused split list is the union over all origins, so a given
            # origin may not have this split at all; skip it instead of raising KeyError.
            if split not in multi_stream:
                continue
            iterator = iter(multi_stream[split])

            if self.max_instances_per_origin is None:
                yield from iterator
                continue

            for _ in range(self.max_instances_per_origin):
                try:
                    instance = next(iterator)
                except StopIteration:
                    # This origin ran out before reaching the cap.
                    break
                yield instance
class WeightedFusion(BaseFusion):
    """Fusion operator that randomly interleaves multiple streams according to per-origin weights.

    For each split, the next instance is drawn from an origin chosen at random
    with probability proportional to its weight. An exhausted origin is removed
    from the draw. The fused stream ends when every origin is exhausted or
    ``max_total_examples`` instances have been yielded. Origins that lack the
    split, or whose weight is not positive, never contribute.

    Args:
        origins: List of SourceOperator objects.
        weights: List of weights, one per origin (same order as ``origins``).
        max_total_examples: Maximum number of instances to yield per split. If None, all instances are returned.
        include_splits: List of splits to include. If None, all splits are included.
    """

    origins: List[SourceOperator] = None
    weights: List[float] = None
    max_total_examples: Optional[int] = None

    def verify(self):
        super().verify()
        assert self.origins is not None, "origins must be specified"
        assert self.weights is not None, "weights must be specified"
        assert len(self.origins) == len(
            self.weights
        ), "origins and weights must have the same length"

    def fusion_generator(self, split) -> Generator:
        # Collect only the origins that can actually be drawn for this split.
        # - A zero-weight origin is never selected by random.choices.
        # - Keeping a zero-weight origin would make choices() raise ValueError
        #   (zero total weight) once all positive-weight origins are exhausted.
        # - Dropping them leaves the draw sequence unchanged.
        iterators = []
        weights = []
        for origin, weight in zip(self.origins, self.weights):
            if weight <= 0:
                continue
            multi_stream = origin()
            if split not in multi_stream:
                continue
            iterators.append(iter(multi_stream[split]))
            weights.append(weight)

        # NOTE(review): presumably deterministic per split via the sub_seed; confirm in random_utils.
        random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
        total_examples = 0
        # Strict '<' so that exactly max_total_examples instances are yielded.
        while iterators and (
            self.max_total_examples is None
            or total_examples < self.max_total_examples
        ):
            # Draw a position rather than the iterator itself, so removal does not
            # rely on iterator equality.
            index = random_generator.choices(
                population=range(len(iterators)), weights=weights
            )[0]
            try:
                instance = next(iterators[index])
            except StopIteration:
                iterators.pop(index)
                weights.pop(index)
                continue
            yield instance
            total_examples += 1