"""Tests for the chunking / frame-label helpers in ``s3prl.dataio.dataset.frame_label``."""

from s3prl.dataio.dataset.frame_label import (
    chunk_labels_to_frame_tensor_label,
    chunking,
)


def test_chunking():
    """Check how many chunks ``chunking`` yields with and without its final flag."""
    # NOTE(review): positional args presumably are
    # (start_sec, end_sec, chunk_secs, step_secs, use_last_samples) -- confirm
    # against the ``chunking`` signature.
    chunks = list(chunking(0.0, 8.5, 2.0, 1.0, False))
    assert len(chunks) == 7

    # With the final flag enabled, one more chunk is expected than the
    # 2.0/1.0 windowing of [1.1, 8.5] would give with the flag off.
    chunks = list(chunking(1.1, 8.5, 2.0, 1.0, True))
    assert len(chunks) == 8


def test_frame_tensor_label():
    """Check frame-level labels produced for the chunk [1.5, 4.0].

    Both positive frames (segment overlaps the frame) and negative frames
    (segment does not reach the frame) are asserted, so an implementation
    that simply marks everything active cannot pass.
    """
    # Assumed tuple layout: (class_index, start_sec, end_sec) -- TODO confirm.
    labels = [
        (0, 3.0, 4.1),
        (1, 1.2, 3.2),
    ]
    # NOTE(review): 3 is presumably the number of classes and 160 the frame
    # shift in samples; the result is indexed as label[frame, class] below.
    label = chunk_labels_to_frame_tensor_label(1.5, 4.0, labels, 3, 160)

    # Class 0 covers 3.0-4.1 s: it reaches the end of the chunk (4.0 s)
    # but starts well after the chunk start (1.5 s).
    assert label[-1, 0] == 1
    assert label[0, 0] == 0

    # Class 1 covers 1.2-3.2 s: it is active at the chunk start but ends
    # before the chunk end.
    assert label[0, 1] == 1
    assert label[-1, 1] == 0

    # Class 2 never appears in ``labels``, so its column must stay empty.
    assert label[:, 2].sum() == 0