# Spaces: Runtime error
# Copied from Megatron.
def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments.

    The sentences of ``sample`` are split at a boundary into a leading group
    and a trailing group; each group is flattened into one token list. With
    probability 0.5 the two groups are then swapped.

    Args:
        sample: Sequence of sentences, each an iterable of tokens. Only
            ``len(sample)`` and integer indexing are used.
        np_rng: Random generator exposing ``randint(low, high)`` with an
            exclusive upper bound and ``random()`` returning a float
            (presumably a ``numpy.random.RandomState`` -- confirm with the
            caller).

    Returns:
        A tuple ``(tokens_a, tokens_b, is_next_random)``. ``is_next_random``
        is ``True`` when the segments were swapped, i.e. ``tokens_a`` holds
        the trailing sentences and ``tokens_b`` the leading ones.

    Raises:
        AssertionError: If ``sample`` has fewer than two sentences.
    """
    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into the `A`.
    a_end = 1
    if n_sentences >= 3:
        # Note that randint in numpy is exclusive of the upper bound, so
        # at least one sentence is always left over for `B`.
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = [token for j in range(a_end) for token in sample[j]]

    # Second part: everything after the split point.
    tokens_b = [token for j in range(a_end, n_sentences) for token in sample[j]]

    # Random next: swap the segments half of the time. The draw happens after
    # the split draw so the random stream is consumed in the same order.
    is_next_random = bool(np_rng.random() < 0.5)
    if is_next_random:
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random