# File size: 1,979 Bytes
# 0b32ad6
from collections import Counter
import pytest
from dotenv import dotenv_values
from s3prl.dataio.corpus.speech_commands import SpeechCommandsV1
def _class_counter(data_dict) -> Counter:
    """Count how many entries of *data_dict* fall under each class name.

    Args:
        data_dict: Mapping from a data id to a per-item record. Each record
            must provide a ``"class_name"`` key; the data ids themselves are
            not used.

    Returns:
        A ``Counter`` mapping each ``"class_name"`` value to its number of
        occurrences. An empty mapping yields an empty ``Counter``.

    Raises:
        KeyError: If a record has no ``"class_name"`` key.
    """
    # Only the records matter here, so iterate the values directly and let
    # Counter do the tallying instead of updating it one item at a time.
    return Counter(record["class_name"] for record in data_dict.values())
@pytest.mark.corpus
def test_speech_commands():
    """Check the class set and per-split class counts of SpeechCommandsV1.

    The corpus is built from the ``GSC1`` and ``GSC1_TEST`` entries returned
    by ``dotenv_values()`` (presumably the train and test dataset roots --
    confirm against ``SpeechCommandsV1``). The test asserts that ``all_data``
    contains exactly 12 distinct class names and that the train / valid /
    test splits have the expected number of items per class.

    Raises:
        KeyError: If ``GSC1`` or ``GSC1_TEST`` is missing from the dotenv
            values.
    """
    env = dotenv_values()
    corpus = SpeechCommandsV1(env["GSC1"], env["GSC1_TEST"])

    # Only the records are needed, so build the set straight from the values.
    classes = {record["class_name"] for record in corpus.all_data.values()}
    assert len(classes) == 12, f"{classes}"

    train, valid, test = corpus.data_split

    # These pre-defined numbers are obtained with the old DownstreamExpert
    expected_train = Counter(
        {
            "_unknown_": 32550,
            "stop": 1885,
            "on": 1864,
            "go": 1861,
            "yes": 1860,
            "no": 1853,
            "right": 1852,
            "up": 1843,
            "down": 1842,
            "left": 1839,
            "off": 1839,
            "_silence_": 6,
        }
    )
    expected_valid = Counter(
        {
            "_unknown_": 4221,
            "stop": 246,
            "on": 257,
            "go": 260,
            "yes": 261,
            "no": 270,
            "right": 256,
            "up": 260,
            "down": 264,
            "left": 247,
            "off": 256,
            "_silence_": 6,
        }
    )
    expected_test = Counter(
        {
            "_unknown_": 257,
            "stop": 249,
            "on": 246,
            "go": 251,
            "yes": 256,
            "no": 252,
            "right": 259,
            "up": 272,
            "down": 253,
            "left": 267,
            "off": 262,
            "_silence_": 257,
        }
    )

    assert _class_counter(train) == expected_train
    assert _class_counter(valid) == expected_valid
    assert _class_counter(test) == expected_test
|