Spaces:
Runtime error
Runtime error
Commit
·
1c22425
1
Parent(s):
c95adc4
fixed metrics and weights
Browse files- models/utils.py +3 -1
- preprocessing/dataset.py +3 -1
models/utils.py
CHANGED
|
@@ -37,7 +37,9 @@ def calculate_metrics(
|
|
| 37 |
pred, target, threshold=0.5, prefix="", multi_label=True
|
| 38 |
) -> dict[str, torch.Tensor]:
|
| 39 |
target = target.detach().cpu().numpy()
|
| 40 |
-
pred = pred.detach().cpu()
|
|
|
|
|
|
|
| 41 |
params = {
|
| 42 |
"y_true": target if multi_label else target.argmax(1),
|
| 43 |
"y_pred": np.array(pred > threshold, dtype=float)
|
|
|
|
| 37 |
pred, target, threshold=0.5, prefix="", multi_label=True
|
| 38 |
) -> dict[str, torch.Tensor]:
|
| 39 |
target = target.detach().cpu().numpy()
|
| 40 |
+
pred = pred.detach().cpu()
|
| 41 |
+
pred = nn.functional.softmax(pred, dim=1)
|
| 42 |
+
pred = pred.numpy()
|
| 43 |
params = {
|
| 44 |
"y_true": target if multi_label else target.argmax(1),
|
| 45 |
"y_pred": np.array(pred > threshold, dtype=float)
|
preprocessing/dataset.py
CHANGED
|
@@ -80,7 +80,9 @@ class SongDataset(Dataset):
|
|
| 80 |
|
| 81 |
def get_label_weights(self):
|
| 82 |
n_examples, n_classes = self.dance_labels.shape
|
| 83 |
-
|
|
|
|
|
|
|
| 84 |
|
| 85 |
def _backtrace_audio_path(self, index: int) -> str:
|
| 86 |
return self.audio_paths[self._idx2audio_idx(index)]
|
|
|
|
| 80 |
|
| 81 |
def get_label_weights(self):
|
| 82 |
n_examples, n_classes = self.dance_labels.shape
|
| 83 |
+
weights = n_examples / (n_classes * sum(self.dance_labels))
|
| 84 |
+
weights[np.isinf(weights)] = 0.0
|
| 85 |
+
return torch.from_numpy(weights)
|
| 86 |
|
| 87 |
def _backtrace_audio_path(self, index: int) -> str:
|
| 88 |
return self.audio_paths[self._idx2audio_idx(index)]
|