File size: 1,979 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from collections import Counter

import pytest
from dotenv import dotenv_values

from s3prl.dataio.corpus.speech_commands import SpeechCommandsV1


def _class_counter(data_dict):
    counter = Counter()
    for data_id, data in data_dict.items():
        counter.update([data["class_name"]])
    return counter


@pytest.mark.corpus
def test_speech_commands():
    """Check the SpeechCommandsV1 class inventory and per-split class counts.

    Corpus locations come from the ``GSC1`` and ``GSC1_TEST`` entries of the
    ``.env`` file located by ``dotenv_values()``. These are presumably the
    Speech Commands v1 train/dev data and the separate test set; confirm
    against ``SpeechCommandsV1``. The test asserts that exactly 12 classes are
    present overall. It also asserts that the train/valid/test splits
    reproduce fixed reference per-class counts.
    """
    env = dotenv_values()
    # Fail with an actionable message rather than a bare KeyError, and avoid
    # passing None paths (empty .env entries) into the corpus.
    missing = [key for key in ("GSC1", "GSC1_TEST") if not env.get(key)]
    assert not missing, f"Missing corpus paths in .env: {missing}"

    corpus = SpeechCommandsV1(env["GSC1"], env["GSC1_TEST"])
    all_data = corpus.all_data
    classes = {value["class_name"] for value in all_data.values()}
    assert len(classes) == 12, f"{classes}"

    train, valid, test = corpus.data_split
    train_class_counter = _class_counter(train)
    valid_class_counter = _class_counter(valid)
    test_class_counter = _class_counter(test)

    # These pre-defined numbers are obtained with the old DownstreamExpert
    assert train_class_counter == Counter(
        {
            "_unknown_": 32550,
            "stop": 1885,
            "on": 1864,
            "go": 1861,
            "yes": 1860,
            "no": 1853,
            "right": 1852,
            "up": 1843,
            "down": 1842,
            "left": 1839,
            "off": 1839,
            "_silence_": 6,
        }
    )
    assert valid_class_counter == Counter(
        {
            "_unknown_": 4221,
            "stop": 246,
            "on": 257,
            "go": 260,
            "yes": 261,
            "no": 270,
            "right": 256,
            "up": 260,
            "down": 264,
            "left": 247,
            "off": 256,
            "_silence_": 6,
        }
    )
    assert test_class_counter == Counter(
        {
            "_unknown_": 257,
            "stop": 249,
            "on": 246,
            "go": 251,
            "yes": 256,
            "no": 252,
            "right": 259,
            "up": 272,
            "down": 253,
            "left": 267,
            "off": 262,
            "_silence_": 257,
        }
    )