Upload split_utils.py with huggingface_hub
Browse files- split_utils.py +39 -24
split_utils.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
import itertools
|
| 2 |
-
import random
|
| 3 |
import re
|
|
|
|
| 4 |
|
| 5 |
from .generator_utils import ReusableGenerator
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def parse_random_mix_string(input_str):
|
|
@@ -67,7 +69,7 @@ def parse_slices_string(input_str):
|
|
| 67 |
result_dict = {}
|
| 68 |
|
| 69 |
# Split the input string into a list of sources
|
| 70 |
-
sources = re.split("\+", input_str)
|
| 71 |
for source in sources:
|
| 72 |
# If the source has a slice, parse it
|
| 73 |
match = re.fullmatch(r"(\w+)\[(\d*):(\d*)\]", source)
|
|
@@ -119,7 +121,7 @@ def slice_streams(input_streams, mapping):
|
|
| 119 |
the new streams, which consist of parts of the old streams chained together.
|
| 120 |
|
| 121 |
Raises:
|
| 122 |
-
ValueError: If a stream is supposed to be sliced at an index greater than its length.
|
| 123 |
|
| 124 |
Example:
|
| 125 |
>>> old_streams = {"train": [1, 2, 3, 4, 5, 6, 7, 8, 9], "test": [10, 11, 12, 13, 14]}
|
|
@@ -205,15 +207,30 @@ def build_stream_routing(mapping):
|
|
| 205 |
return stream_mapping
|
| 206 |
|
| 207 |
|
| 208 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
for old_stream_name in new_stream_sources:
|
| 210 |
optinal_streams, weights = stream_routing[old_stream_name]
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
yield item
|
| 217 |
|
| 218 |
|
| 219 |
def random_mix_streams(input_streams, mapping):
|
|
@@ -263,20 +280,18 @@ def random_mix_streams(input_streams, mapping):
|
|
| 263 |
# Build stream routing
|
| 264 |
stream_routing = build_stream_routing(mapping)
|
| 265 |
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
},
|
| 279 |
-
)
|
| 280 |
|
| 281 |
return new_streams
|
| 282 |
|
|
|
|
| 1 |
import itertools
|
|
|
|
| 2 |
import re
|
| 3 |
+
from typing import Dict
|
| 4 |
|
| 5 |
from .generator_utils import ReusableGenerator
|
| 6 |
+
from .random_utils import nested_seed
|
| 7 |
+
from .stream import Stream
|
| 8 |
|
| 9 |
|
| 10 |
def parse_random_mix_string(input_str):
|
|
|
|
| 69 |
result_dict = {}
|
| 70 |
|
| 71 |
# Split the input string into a list of sources
|
| 72 |
+
sources = re.split(r"\+", input_str)
|
| 73 |
for source in sources:
|
| 74 |
# If the source has a slice, parse it
|
| 75 |
match = re.fullmatch(r"(\w+)\[(\d*):(\d*)\]", source)
|
|
|
|
| 121 |
the new streams, which consist of parts of the old streams chained together.
|
| 122 |
|
| 123 |
Raises:
|
| 124 |
+
ValueError: If a stream is supposed to be sliced at an index greater than its length or a negative one.
|
| 125 |
|
| 126 |
Example:
|
| 127 |
>>> old_streams = {"train": [1, 2, 3, 4, 5, 6, 7, 8, 9], "test": [10, 11, 12, 13, 14]}
|
|
|
|
| 207 |
return stream_mapping
|
| 208 |
|
| 209 |
|
| 210 |
+
def rename_split(
    input_streams: Dict[str, "Stream"], mapping: Dict[str, str]
) -> Dict[str, "Stream"]:
    """Return the input streams keyed by their new names.

    Each stream whose name is a key of ``mapping`` is stored under the
    corresponding mapped name; every other stream keeps its original name.
    The stream objects themselves are passed through untouched (they are not
    copied, wrapped, or consumed), and ``input_streams`` is not modified.

    Args:
        input_streams (dict): A dictionary containing the input streams, where
            each key is the name of the stream and the value is an iterable or
            generator representing the stream.
        mapping (dict): A dictionary mapping old stream names to new stream
            names. Keys that do not match any input stream are ignored.

    Returns:
        dict: A new dictionary where each key is the (possibly renamed) stream
        name and each value is the same stream object that was passed in.

    Note:
        If two streams end up under the same new name, the one that appears
        later in ``input_streams`` silently replaces the earlier one.

    Example:
        >>> rename_split({"train": [1, 2], "test": [3]}, {"train": "validation"})
        {'validation': [1, 2], 'test': [3]}
    """
    # ``mapping.get(name, name)`` falls back to the original name, so streams
    # that are not mentioned in the mapping are carried over unchanged.
    return {mapping.get(name, name): stream for name, stream in input_streams.items()}
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def random_mix_generator(new_stream_name, new_stream_sources, stream_routing, input_streams):
    """Yield the items of the source streams that are routed to ``new_stream_name``.

    For every source stream listed in ``new_stream_sources``, each item is
    assigned to one of that source's candidate target streams by a weighted
    random draw; the item is yielded only when the draw selects
    ``new_stream_name``.

    Args:
        new_stream_name: Name of the target stream this generator produces.
        new_stream_sources: Iterable of old stream names that feed the target.
        stream_routing: Mapping from an old stream name to a pair
            ``(candidate target names, weights)`` used for the draw.
        input_streams: Mapping from an old stream name to its iterable of items.

    Yields:
        The items, in source order, whose draw selected ``new_stream_name``.
    """
    for source_name in new_stream_sources:
        candidates, candidate_weights = stream_routing[source_name]
        # NOTE(review): nested_seed is defined in random_utils; it presumably
        # yields a generator seeded from the source name so that draws are
        # reproducible across target streams -- confirm against that module.
        with nested_seed(source_name) as rng:
            for instance in input_streams[source_name]:
                # Exactly one draw per item, whether or not the item is kept.
                (drawn,) = rng.choices(candidates, weights=candidate_weights, k=1)
                if drawn != new_stream_name:
                    continue
                yield instance
|
|
|
|
| 234 |
|
| 235 |
|
| 236 |
def random_mix_streams(input_streams, mapping):
|
|
|
|
| 280 |
# Build stream routing
|
| 281 |
stream_routing = build_stream_routing(mapping)
|
| 282 |
|
| 283 |
+
with nested_seed():
|
| 284 |
+
# Create new stream generators
|
| 285 |
+
for new_stream_name, new_stream_sources in mapping.items():
|
| 286 |
+
new_streams[new_stream_name] = ReusableGenerator(
|
| 287 |
+
random_mix_generator,
|
| 288 |
+
gen_kwargs={
|
| 289 |
+
"new_stream_name": new_stream_name,
|
| 290 |
+
"new_stream_sources": new_stream_sources,
|
| 291 |
+
"stream_routing": stream_routing,
|
| 292 |
+
"input_streams": input_streams,
|
| 293 |
+
},
|
| 294 |
+
)
|
|
|
|
|
|
|
| 295 |
|
| 296 |
return new_streams
|
| 297 |
|