File size: 6,346 Bytes
cbca7b8 b462f85 5852323 300a7be 2d210f5 cc653d8 b462f85 5c545d2 5852323 cbca7b8 b462f85 5c545d2 cbca7b8 b462f85 cbca7b8 5c545d2 b462f85 300a7be 5c545d2 cbca7b8 5c545d2 b462f85 cbca7b8 b462f85 cbca7b8 5852323 5c545d2 cbca7b8 b462f85 cbca7b8 5852323 5c545d2 cbca7b8 b462f85 5c545d2 cbca7b8 b462f85 cbca7b8 5c545d2 b462f85 5c545d2 b462f85 cbca7b8 b462f85 ef1f482 b462f85 5c545d2 5852323 cbca7b8 b462f85 5c545d2 cbca7b8 b462f85 cbca7b8 5c545d2 b462f85 300a7be 5c545d2 cbca7b8 17a636b b462f85 5c545d2 cbca7b8 b462f85 cbca7b8 cc653d8 17a636b b462f85 17a636b b462f85 cbca7b8 b462f85 cbca7b8 b462f85 cbca7b8 b462f85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from abc import abstractmethod
from itertools import islice
from typing import Dict, Generator, List, Optional, Union

from .dataclass import NonPositionalField
from .operator import SourceOperator
from .random_utils import new_random_generator
from .stream import GeneratorStream, MultiStream
from .type_utils import isoftype
class BaseFusion(SourceOperator):
    """Abstract source operator that merges several MultiStreams into one.

    Every origin is a SourceOperator; calling it (done once, in ``prepare``)
    yields a MultiStream. Subclasses decide how the instances of one split
    are combined by implementing ``fusion_generator``.

    Args:
        origins: a dict of named SourceOperator objects (each yields a
            MultiStream when called), or a list of such objects.
        include_splits: names of the splits to keep from the input
            MultiStreams. If None, every split found in any origin is kept.
    """

    origins: Union[List[SourceOperator], Dict[str, SourceOperator]]
    include_splits: Optional[List[str]] = NonPositionalField(default=None)

    @abstractmethod
    def fusion_generator(self, split) -> Generator:
        """Yield the fused instances of ``split``; supplied by subclasses."""

    def prepare(self):
        assert isoftype(self.origins, Dict[str, SourceOperator]) or isoftype(
            self.origins, List[SourceOperator]
        )
        # List origins are keyed by their integer position, dict origins by
        # their name; each origin is invoked exactly once, in order.
        if isinstance(self.origins, list):
            keyed_sources = enumerate(self.origins)
        else:
            keyed_sources = self.origins.items()
        self.named_origins = {key: source() for key, source in keyed_sources}

    def splits(self) -> List[str]:
        """Return the union of split names over all origins.

        Names keep first-seen order and are filtered by ``include_splits``
        when that is set.
        """
        selected = []
        for multi_stream in self.named_origins.values():
            for split_name in multi_stream.keys():
                if split_name in selected:
                    continue
                if (
                    self.include_splits is not None
                    and split_name not in self.include_splits
                ):
                    continue
                selected.append(split_name)
        return selected

    def process(
        self,
    ) -> MultiStream:
        # One lazily generated stream per split, each backed by fusion_generator.
        return MultiStream(
            {
                split_name: GeneratorStream(
                    self.fusion_generator, gen_kwargs={"split": split_name}
                )
                for split_name in self.splits()
            }
        )
class FixedFusion(BaseFusion):
    """FixedFusion operator that combines multiple multistreams into one, limiting the number of instances taken from each split of each input multistream.

    For each split, the origins are visited in order and the instances of
    that split are emitted origin after origin (each origin capped at
    ``max_instances_per_origin_split``). Origins lacking the split are skipped.
    For named (dict) origins, the origin name is recorded in the instance's
    ``"group"`` field, prefixed to any existing group.

    Args:
        origins: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
        include_splits: List of splits (stream_names) to include, over all input multistreams. If None, all splits are included.
        max_instances_per_origin_split: Number of instances to take from each input split of each input multistream.
            If None, all instances of each split (that is specified in include_splits) are included in the result.
    """

    max_instances_per_origin_split: Optional[int] = None

    # flake8: noqa: C901
    def fusion_generator(self, split) -> Generator:
        cap = self.max_instances_per_origin_split
        # islice stops once `cap` instances were taken, so no extra instance is
        # pulled from the origin's stream past the cap. Non-positive caps are
        # clamped to 0 (islice rejects negatives) and so yield nothing.
        stop = None if cap is None else max(cap, 0)
        for origin_name, origin in self.named_origins.items():
            if split not in origin:
                continue
            for instance in islice(origin[split], stop):
                if isinstance(origin_name, str):
                    # Named origins (not anonymous list entries) are recorded
                    # in the instance.
                    # NOTE(review): the instance is modified in place; if an
                    # origin re-serves the same dict objects on re-iteration,
                    # the prefix would be applied again -- confirm.
                    if "group" in instance:
                        instance["group"] = origin_name + "/" + instance["group"]
                    else:
                        instance["group"] = origin_name
                yield instance
class WeightedFusion(BaseFusion):
    """Fusion operator that combines multiple MultiStream-s by weighted sampling.

    For each split, instances are drawn one at a time: an origin is picked at
    random with probability proportional to its weight, and that origin's
    next instance of the split is emitted. An origin is dropped once it is
    exhausted; origins that lack the split or have weight 0 never contribute.
    For named (dict) origins, the origin name is recorded in the instance's
    ``"group"`` field, prefixed to any existing group.

    Args:
        origins: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
        weights: Dict of named non-negative weights for each origin, or a list thereof
            (same keys, or same length, as origins)
        max_total_examples: Total number of instances to return per returned split.
            If None, all instances are returned
    """

    origins: Union[Dict[str, SourceOperator], List[SourceOperator]] = None
    weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
    max_total_examples: Optional[int] = None

    def verify(self):
        super().verify()
        assert self.origins is not None, "origins must be specified"
        assert self.weights is not None, "weights must be specified"
        assert len(self.origins) == len(
            self.weights
        ), "origins and weights must have the same length"
        assert isoftype(self.origins, Dict[str, SourceOperator]) or isoftype(
            self.origins, List[SourceOperator]
        )
        assert isoftype(self.weights, Dict[str, Union[int, float]]) or isoftype(
            self.weights, List[Union[int, float]]
        )
        assert isinstance(self.origins, dict) == isinstance(self.weights, dict)
        if isinstance(self.weights, dict):
            # Equal lengths are not enough: every origin name needs a weight,
            # otherwise sampling fails later with a KeyError.
            assert set(self.origins) == set(
                self.weights
            ), "origins and weights must have the same keys"
            weight_values = list(self.weights.values())
        else:
            weight_values = list(self.weights)
        # random.choices assumes non-negative weights.
        assert all(w >= 0 for w in weight_values), "weights must be non-negative"

    def prepare(self):
        super().prepare()
        # Weights are keyed like named_origins: list position or dict name.
        if isinstance(self.weights, list):
            keyed_weights = enumerate(self.weights)
        else:
            keyed_weights = self.weights.items()
        self.named_weights = {key: float(value) for key, value in keyed_weights}

    def fusion_generator(self, split) -> Generator:
        # splits() is the union over origins, so some origins may lack this
        # split: skip them instead of raising KeyError. Zero-weight origins are
        # dropped too: they are never drawn while a positive-weight origin
        # remains, and if only they were left random.choices would fail on a
        # zero total. Dropping them leaves the sampled sequence unchanged.
        iterators = {
            named_origin: iter(origin[split])
            for named_origin, origin in self.named_origins.items()
            if split in origin and self.named_weights[named_origin] > 0
        }
        total_examples = 0
        random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
        while iterators and (
            self.max_total_examples is None or total_examples < self.max_total_examples
        ):
            population = list(iterators.keys())
            origin_name = random_generator.choices(
                population=population,
                weights=[self.named_weights[name] for name in population],
            )[0]
            try:
                instance = next(iterators[origin_name])
            except StopIteration:
                # This origin is exhausted; keep sampling among the others.
                iterators.pop(origin_name)
                continue
            if isinstance(origin_name, str):
                # NOTE(review): the instance is modified in place -- confirm
                # origins never re-serve the same dict objects.
                if "group" in instance:
                    instance["group"] = origin_name + "/" + instance["group"]
                else:
                    instance["group"] = origin_name
            total_examples += 1
            yield instance
|