File size: 2,131 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import logging
import random
from collections import OrderedDict

from s3prl.dataio.sampler import SortedBucketingSampler, SortedSliceSampler

logger = logging.getLogger(__name__)


def test_sorted_slice_sampler():
    """Exercise ``SortedSliceSampler`` over several epochs on random lengths.

    1000 lengths are drawn uniformly from [16000 * 3, 16000 * 8]. For every
    epoch the test checks that:

    * each batch lists its ids in non-increasing length order;
    * a batch whose first (longest) item exceeds ``max_length`` holds exactly
      ``batch_size // 2`` ids;
    * no batch size other than ``batch_size`` / ``batch_size // 2`` shows up
      more than once;
    * ``len(sampler)`` equals the number of lengths.
    """
    batch_size = 16
    max_length = 16000 * 5
    lengths = [random.randint(16000 * 3, 16000 * 8) for _ in range(1000)]
    regular_sizes = [batch_size, batch_size // 2]

    sampler = SortedSliceSampler(
        lengths,
        batch_size=batch_size,
        max_length=max_length,
    )

    for epoch in range(5):
        sampler.set_epoch(epoch)

        for ids in sampler:
            lens = [lengths[i] for i in ids]
            assert lens == sorted(lens, reverse=True)
            # Batches headed by an over-long item are expected to be halved.
            if lens[0] > max_length:
                assert len(lens) == batch_size // 2

        # Irregular sizes must all be distinct from one another.
        sizes = [len(ids) for ids in sampler]
        irregular = [size for size in sizes if size not in regular_sizes]
        assert len(irregular) == len(set(irregular))
        assert len(sampler) == len(lengths)


def test_sorted_bucketing_sampler():
    """Exercise ``SortedBucketingSampler`` (``shuffle=False``) over several epochs.

    1000 lengths are drawn uniformly from [16000 * 3, 16000 * 8]. For every
    epoch the test checks that:

    * each batch lists its ids in non-increasing length order;
    * a batch whose first (longest) item exceeds ``max_length`` holds exactly
      ``batch_size // 2`` ids;
    * at most one batch has a size other than ``batch_size`` or
      ``batch_size // 2`` (presumably the trailing remainder batch);
    * the number of batches lies strictly between ``n / batch_size`` and
      ``n / (batch_size // 2)``.
    """
    batch_size = 16
    max_length = 16000 * 5
    lengths = [random.randint(16000 * 3, 16000 * 8) for _ in range(1000)]

    sampler = SortedBucketingSampler(
        lengths,
        batch_size=batch_size,
        max_length=max_length,
        shuffle=False,
    )

    for epoch in range(5):
        sampler.set_epoch(epoch)
        for batch_ids in sampler:
            batch_lengths = [lengths[idx] for idx in batch_ids]
            assert sorted(batch_lengths, reverse=True) == batch_lengths
            if batch_lengths[0] > max_length:
                assert len(batch_lengths) == batch_size // 2

        # The comprehension variable must not be named ``batch_size``: it would
        # shadow the outer value inside the comprehension's own scope, making
        # the filter compare each size with itself (always False) and turning
        # the assertion below into a no-op.
        batch_sizes = [len(batch_indices) for batch_indices in sampler]
        other_batch_sizes = [
            size
            for size in batch_sizes
            if size not in (batch_size, batch_size // 2)
        ]
        assert len(other_batch_sizes) <= 1
        assert (
            len(lengths) / batch_size
            < len(sampler)
            < len(lengths) / (batch_size // 2)
        )