Spaces:
Runtime error
Runtime error
Commit
·
6ba247f
1
Parent(s):
e748bc2
added class count function
Browse files- preprocessing/dataset.py +25 -4
preprocessing/dataset.py
CHANGED
|
@@ -2,7 +2,7 @@ import importlib
|
|
| 2 |
import os
|
| 3 |
from typing import Any
|
| 4 |
import torch
|
| 5 |
-
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
import torchaudio as ta
|
|
@@ -278,6 +278,7 @@ class DanceDataModule(pl.LightningDataModule):
|
|
| 278 |
target_classes: list[str] = None,
|
| 279 |
batch_size: int = 64,
|
| 280 |
num_workers=10,
|
|
|
|
| 281 |
):
|
| 282 |
super().__init__()
|
| 283 |
self.val_proportion = val_proportion
|
|
@@ -286,6 +287,10 @@ class DanceDataModule(pl.LightningDataModule):
|
|
| 286 |
self.target_classes = target_classes
|
| 287 |
self.batch_size = batch_size
|
| 288 |
self.num_workers = num_workers
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
self.dataset = dataset
|
| 290 |
|
| 291 |
def setup(self, stage: str):
|
|
@@ -317,9 +322,10 @@ class DanceDataModule(pl.LightningDataModule):
|
|
| 317 |
)
|
| 318 |
|
| 319 |
def get_label_weights(self):
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
|
|
|
| 323 |
return torch.mean(torch.stack(weights), dim=0) # TODO: Make this weighted
|
| 324 |
|
| 325 |
|
|
@@ -349,3 +355,18 @@ def get_datasets(dataset_config: dict, feature_extractor) -> Dataset:
|
|
| 349 |
ProvidedDataset = getattr(module, class_name)
|
| 350 |
datasets.append(ProvidedDataset(**kwargs))
|
| 351 |
return PipelinedDataset(ConcatDataset(datasets), feature_extractor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import os
|
| 3 |
from typing import Any
|
| 4 |
import torch
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, Subset
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
import torchaudio as ta
|
|
|
|
| 278 |
target_classes: list[str] = None,
|
| 279 |
batch_size: int = 64,
|
| 280 |
num_workers=10,
|
| 281 |
+
data_subset=None,
|
| 282 |
):
|
| 283 |
super().__init__()
|
| 284 |
self.val_proportion = val_proportion
|
|
|
|
| 287 |
self.target_classes = target_classes
|
| 288 |
self.batch_size = batch_size
|
| 289 |
self.num_workers = num_workers
|
| 290 |
+
|
| 291 |
+
if data_subset is not None and float(data_subset) != 1.0:
|
| 292 |
+
dataset, _ = random_split(dataset, [data_subset, 1 - data_subset])
|
| 293 |
+
|
| 294 |
self.dataset = dataset
|
| 295 |
|
| 296 |
def setup(self, stage: str):
|
|
|
|
| 322 |
)
|
| 323 |
|
| 324 |
def get_label_weights(self):
|
| 325 |
+
dataset = (
|
| 326 |
+
self.dataset.dataset if isinstance(self.dataset, Subset) else self.dataset
|
| 327 |
+
)
|
| 328 |
+
weights = [ds.song_dataset.get_label_weights() for ds in dataset._data.datasets]
|
| 329 |
return torch.mean(torch.stack(weights), dim=0) # TODO: Make this weighted
|
| 330 |
|
| 331 |
|
|
|
|
| 355 |
ProvidedDataset = getattr(module, class_name)
|
| 356 |
datasets.append(ProvidedDataset(**kwargs))
|
| 357 |
return PipelinedDataset(ConcatDataset(datasets), feature_extractor)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def get_class_counts(config: dict):
|
| 361 |
+
# TODO: Figure out why music4dance has fractional labels
|
| 362 |
+
dataset = get_datasets(config["datasets"], lambda x: x)
|
| 363 |
+
counts = sum(
|
| 364 |
+
np.sum(
|
| 365 |
+
np.arange(len(config["dance_ids"]))
|
| 366 |
+
== np.expand_dims(ds.song_dataset.dance_labels.argmax(1), 1),
|
| 367 |
+
axis=0,
|
| 368 |
+
)
|
| 369 |
+
for ds in dataset._data.datasets
|
| 370 |
+
)
|
| 371 |
+
labels = sorted(config["dance_ids"])
|
| 372 |
+
return dict(zip(labels, counts))
|