thunnai committed
Commit 169a7c1 · 1 Parent(s): 4cd5da8

port to support hf
.github/workflows/pre-commit.yaml ADDED
@@ -0,0 +1,14 @@
+ name: Pre-commit
+ on:
+   pull_request:
+   push:
+     branches: [main]
+ jobs:
+   pre-commit:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v3
+       - uses: actions/setup-python@v4
+         with:
+           python-version: "3.10"
+       - uses: pre-commit/[email protected]

.github/workflows/run-pytest.yaml ADDED
@@ -0,0 +1,24 @@
+ name: PyTest
+
+ on:
+   push:
+     branches: [main]
+   pull_request:
+     branches: [main]
+
+ jobs:
+   run-pytest:
+     name: python
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v4
+
+       - name: Install uv
+         uses: astral-sh/setup-uv@v5
+
+       - name: Install the project
+         run: uv sync --all-extras --dev
+
+       - name: Run tests
+         run: uv run pytest test/

.gitignore CHANGED
@@ -1,3 +1,7 @@
+ .gradio/
+ *.pt
+ *.pth
+
  # Byte-compiled / optimized / DLL files
  __pycache__/
  *.py[cod]
@@ -161,4 +165,4 @@ cython_debug/
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
  # .idea/
  .vscode/
- .ruff_cache/
+ .ruff_cache/

.pre-commit-config.yaml ADDED
@@ -0,0 +1,22 @@
+ repos:
+   - repo: https://github.com/google/yamlfmt
+     rev: v0.16.0
+     hooks:
+       - id: yamlfmt
+   - repo: https://github.com/gitleaks/gitleaks
+     rev: v8.23.3
+     hooks:
+       - id: gitleaks
+   - repo: https://github.com/astral-sh/uv-pre-commit
+     # uv version.
+     rev: 0.5.30
+     hooks:
+       - id: uv-lock
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     rev: v0.9.6
+     hooks:
+       - id: ruff
+         types_or: [python, pyi]
+         args: [--fix]
+       - id: ruff-format
+         types_or: [python, pyi]

examples/example.jsonl ADDED
@@ -0,0 +1 @@
+ {"path": "sample_audio/libritts_spk-84.wav"}

examples/predict_from_jsonl.py ADDED
@@ -0,0 +1,9 @@
+ from audiobox_aesthetics.inference import AudioBoxAesthetics, AudioFileList
+
+ model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+ model.eval()
+
+
+ audio_file_list = AudioFileList.from_jsonl("examples/example.jsonl")
+ predictions = model.predict_from_files(audio_file_list)
+ print(predictions)

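A possible follow-up to this example (not part of the commit): persisting the returned scores next to the input manifest. This is a minimal sketch assuming `predictions` is a list of per-axis score dicts aligned with `audio_file_list.files`; the output filename is a placeholder.

import json

# Write one JSON object per input file, merging the path with its predicted scores.
with open("examples/example_predictions.jsonl", "w") as f:
    for audio_file, scores in zip(audio_file_list.files, predictions):
        f.write(json.dumps({"path": audio_file.path, **scores}) + "\n")
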
examples/predict_single_file.py ADDED
@@ -0,0 +1,8 @@
+ from audiobox_aesthetics.inference import AudioBoxAesthetics
+
+ model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+ model.eval()
+
+ wav = model.load_audio("sample_audio/libritts_spk-84.wav")
+ predictions = model.predict_from_wavs(wav)
+ print(predictions)

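A small batch-scoring sketch (not part of the commit), assuming `predict_from_wavs` accepts a list of pre-loaded waveforms as defined in `inference.py`; both sample paths are files added in this commit.

from audiobox_aesthetics.inference import AudioBoxAesthetics

model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
model.eval()

# Load both bundled samples and score them in one call.
wavs = [
    model.load_audio("sample_audio/libritts_spk-84.wav"),
    model.load_audio("sample_audio/libritts_spk-3170.wav"),
]
predictions = model.predict_from_wavs(wavs)  # one score dict per input wav
print(predictions)
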
pyproject.toml CHANGED
@@ -6,23 +6,21 @@ build-backend = "setuptools.build_meta"
  name = "audiobox_aesthetics"
  version = "0.0.1"
  authors = [
-     {name="Andros Tjandra", email="[email protected]"},
-     {name="Yi-Chiao Wu"},
-     {name="Baishan Guo"},
-     {name="John Hoffman"},
-     {name="Brian Ellis"},
-     {name="Apoorv Vyas"},
-     {name="Bowen Shi"},
-     {name="Sanyuan Chen"},
-     {name="Matt Le"},
-     {name="Nick Zacharov"},
-     {name="Carleigh Wood"},
-     {name="Ann Lee"},
-     {name="Wei-ning Hsu"}
- ]
- maintainers = [
-     {name="Andros Tjandra", email="[email protected]"}
+     { name = "Andros Tjandra", email = "[email protected]" },
+     { name = "Yi-Chiao Wu" },
+     { name = "Baishan Guo" },
+     { name = "John Hoffman" },
+     { name = "Brian Ellis" },
+     { name = "Apoorv Vyas" },
+     { name = "Bowen Shi" },
+     { name = "Sanyuan Chen" },
+     { name = "Matt Le" },
+     { name = "Nick Zacharov" },
+     { name = "Carleigh Wood" },
+     { name = "Ann Lee" },
+     { name = "Wei-ning Hsu" },
  ]
+ maintainers = [{ name = "Andros Tjandra", email = "[email protected]" }]
  description = "Unified automatic quality assessment for speech, music, and sound."
  requires-python = ">=3.9"
  classifiers = [
@@ -30,14 +28,17 @@ classifiers = [
      "Operating System :: OS Independent",
  ]
  readme = "README.md"
- license = {file = "LICENSE"}
+ license = { file = "LICENSE" }

  dependencies = [
-     "numpy",
-     "torch>=2.2.0",
-     "torchaudio",
-     "tqdm",
-     "submitit"
+     "numpy",
+     "torch>=2.2.0",
+     "torchaudio",
+     "tqdm",
+     "submitit",
+     "huggingface-hub>=0.28.1",
+     "pydantic>=2.10.6",
+     "safetensors>=0.5.2",
  ]

  [project.scripts]
@@ -47,4 +48,5 @@ audio-aes = "audiobox_aesthetics.cli:app"
  Homepage = "https://github.com/facebookresearch/audiobox-aesthetics"
  Issues = "https://github.com/facebookresearch/audiobox-aesthetics/issues"

-
+ [dependency-groups]
+ dev = ["gradio>=4.44.1", "ipykernel>=6.29.5", "pytest>=8.3.4"]

sample_audio/libritts_spk-3170.wav ADDED
Binary file (292 kB).
 
sample_audio/libritts_spk-84.wav ADDED
Binary file (287 kB).
 
sample_audio/test.jsonl ADDED
@@ -0,0 +1 @@
+ {"path": "sample_audio/libritts_spk-84.wav"}

src/audiobox_aesthetics/cli.py CHANGED
@@ -14,7 +14,7 @@ import requests

  import submitit
  from tqdm import tqdm
- from .infer import load_dataset, main_predict
+ from audiobox_aesthetics.infer import load_dataset, main_predict

  logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

src/audiobox_aesthetics/demo.py ADDED
@@ -0,0 +1,86 @@
+ import gradio as gr
+ from audiobox_aesthetics.inference import (
+     AudioBoxAesthetics,
+     AudioFile,
+     AXIS_NAME_LOOKUP,
+ )
+
+ # Load the pre-trained model
+ model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+ model.eval()
+
+
+ def predict_aesthetics(audio_file):
+     # Create an AudioFile instance
+     audio_file_instance = AudioFile(path=audio_file)
+
+     # Predict using the model
+     predictions = model.predict_from_files(audio_file_instance)
+
+     single_prediction = predictions[0]
+
+     data_view = [
+         [AXIS_NAME_LOOKUP[key], value] for key, value in single_prediction.items()
+     ]
+
+     return single_prediction, data_view
+
+
+ def create_demo():
+     # Create a Gradio Blocks interface
+     with gr.Blocks() as demo:
+         gr.Markdown("# AudioBox Aesthetics Prediction")
+         with gr.Group():
+             gr.Markdown("""Upload an audio file to predict its aesthetic scores.
+
+             This demo uses the AudioBox Aesthetics model to predict aesthetic scores for audio files along 4 axes:
+             - Content Enjoyment (CE)
+             - Content Usefulness (CU)
+             - Production Complexity (PC)
+             - Production Quality (PQ)
+
+             Scores range from 0 to 10.
+
+             For more details, see the [paper](https://arxiv.org/abs/2502.05139) or [code](https://github.com/facebookresearch/audiobox-aesthetics/tree/main).
+             """)
+
+         with gr.Row():
+             with gr.Group():
+                 with gr.Column():
+                     audio_input = gr.Audio(
+                         sources="upload", type="filepath", label="Upload Audio"
+                     )
+                     submit_button = gr.Button("Predict", variant="primary")
+             with gr.Group():
+                 with gr.Column():
+                     output_data = gr.Dataframe(
+                         headers=["Axes name", "Score"],
+                         datatype=["str", "number"],
+                         label="Aesthetic Scores",
+                     )
+                     output_text = gr.Textbox(label="Raw prediction", interactive=False)
+
+         submit_button.click(
+             predict_aesthetics,
+             inputs=audio_input,
+             outputs=[output_text, output_data],
+         )
+
+         # Add examples
+         gr.Examples(
+             examples=[
+                 "sample_audio/libritts_spk-84.wav",
+                 "sample_audio/libritts_spk-3170.wav",
+             ],
+             inputs=audio_input,
+             outputs=[output_text, output_data],
+             fn=predict_aesthetics,
+             cache_examples=True,
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = create_demo()
+     demo.launch()

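As a usage note (not part of the commit): the same Blocks app can be launched from another script with explicit server options. The host and port below are illustrative assumptions, not settings used by the repository.

from audiobox_aesthetics.demo import create_demo

# Build the Gradio Blocks app and serve it on an explicit address (values are examples).
demo = create_demo()
demo.launch(server_name="0.0.0.0", server_port=7860)
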
src/audiobox_aesthetics/export_model_to_hf.py ADDED
@@ -0,0 +1,77 @@
+ import requests
+ import os
+ import argparse
+ import torch
+
+ from audiobox_aesthetics.inference import AudioBoxAesthetics
+
+ if __name__ == "__main__":
+     # Set up argument parser
+     parser = argparse.ArgumentParser(
+         description="Download and test AudioBox Aesthetics model"
+     )
+     parser.add_argument(
+         "--checkpoint-url",
+         default="https://dl.fbaipublicfiles.com/audiobox-aesthetics/checkpoint.pt",
+         help="URL for the base checkpoint",
+     )
+     parser.add_argument(
+         "--model-name",
+         default="audiobox-aesthetics",
+         help="Name to save/load the pretrained model",
+     )
+     parser.add_argument(
+         "--push-to-hub",
+         action="store_true",
+         help="Push the model to the Hugging Face Hub",
+     )
+     args = parser.parse_args()
+
+     checkpoint_local_path = "base_checkpoint.pth"
+
+     if not os.path.exists(checkpoint_local_path):
+         print("Downloading base checkpoint")
+         response = requests.get(args.checkpoint_url)
+         with open(checkpoint_local_path, "wb") as f:
+             f.write(response.content)
+
+     # get model config from the base checkpoint
+     checkpoint = torch.load(
+         checkpoint_local_path, map_location="cpu", weights_only=True
+     )
+     model_cfg = checkpoint["model_cfg"]
+
+     # extract normalization params from the base checkpoint
+     target_transform = checkpoint["target_transform"]
+
+     target_transform = {
+         axis: {
+             "mean": checkpoint["target_transform"][axis]["mean"],
+             "std": checkpoint["target_transform"][axis]["std"],
+         }
+         for axis in target_transform.keys()
+     }
+
+     model = AudioBoxAesthetics(
+         sample_rate=16_000, target_transform=target_transform, **model_cfg
+     )
+
+     model._load_base_checkpoint(checkpoint_local_path)
+     print("✅ Loaded model from base checkpoint")
+
+     model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)
+     print(f"✅ Saved model to {args.model_name}")
+     if args.push_to_hub:
+         model.push_to_hub(args.model_name)
+         print(f"✅ Pushed model to Hub under {args.model_name}")
+
+     # test load from pretrained
+     model = AudioBoxAesthetics.from_pretrained(args.model_name)
+     model.eval()
+     print(f"✅ Loaded model from pretrained {args.model_name}")
+
+     # test inference
+     wav = model.load_audio("sample_audio/libritts_spk-84.wav")
+     predictions = model.predict_from_wavs(wav)
+     print(predictions)
+     print("✅ Inference test passed")

src/audiobox_aesthetics/infer.py CHANGED
@@ -14,7 +14,7 @@ import torch
  import torchaudio
  import torch.nn.functional as F

- from .model.aes_wavlm import Normalize, WavlmAudioEncoderMultiOutput
+ from audiobox_aesthetics.model.aes_wavlm import Normalize, WavlmAudioEncoderMultiOutput

  Batch = Dict[str, Any]

@@ -113,6 +113,8 @@ class AesWavlmPredictorMultiOutput:
              "bf16": torch.bfloat16,
          }.get(self.precision)

+         print("using precision", self.precision)
+
          self.target_transform = {
              axis: Normalize(
                  mean=ckpt["target_transform"][axis]["mean"],
@@ -205,8 +207,8 @@ def main_predict(input_file, ckpt, batch_size=10):
      for ii in tqdm(range(0, len(metadata), batch_size)):
          output = predictor.forward(metadata[ii : ii + batch_size])
          outputs.extend(output)
-     assert len(outputs) == len(
-         metadata
-     ), f"Output {len(outputs)} != input {len(metadata)} length"
+     assert len(outputs) == len(metadata), (
+         f"Output {len(outputs)} != input {len(metadata)} length"
+     )

      return outputs

src/audiobox_aesthetics/inference.py ADDED
@@ -0,0 +1,217 @@
+ import re
+
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin
+
+ from audiobox_aesthetics.model.aes_wavlm import Normalize, WavlmAudioEncoderMultiOutput
+ from audiobox_aesthetics.infer import make_inference_batch
+
+ from pydantic import BaseModel
+ import torchaudio
+
+ from pydantic import BaseModel, Field
+ from typing import Optional, List
+ import json
+
+ AXIS_NAME_LOOKUP = {
+     "CE": "Content Enjoyment",
+     "CU": "Content Usefulness",
+     "PC": "Production Complexity",
+     "PQ": "Production Quality",
+ }
+
+
+ class AudioFile(BaseModel):
+     """
+     Audio file to be processed
+     """
+
+     path: str
+     start_time: Optional[float] = Field(None, description="Start time in seconds")
+     end_time: Optional[float] = Field(None, description="End time in seconds")
+
+
+ class AudioFileList(BaseModel):
+     """
+     List of audio files to be processed
+     """
+
+     files: List[AudioFile]
+
+     @classmethod
+     def from_jsonl(cls, filename: str) -> "AudioFileList":
+         audio_files = []
+         with open(filename, "r") as f:
+             for line in f:
+                 data = json.loads(line.strip())
+                 audio_file = AudioFile(**data)
+                 audio_files.append(audio_file)
+         return cls(files=audio_files)
+
+
+ # model
+
+
+ class AudioBoxAesthetics(
+     nn.Module,
+     PyTorchModelHubMixin,
+     library_name="audiobox-aesthetics",
+     repo_url="https://github.com/facebookresearch/audiobox-aesthetics",
+ ):
+     def __init__(
+         self,
+         proj_num_layer: int = 1,
+         proj_ln: bool = False,
+         proj_act_fn: str = "gelu",
+         proj_dropout: float = 0.0,
+         nth_layer: int = 13,
+         use_weighted_layer_sum: bool = True,
+         precision: str = "32",
+         normalize_embed: bool = True,
+         output_dim: int = 1,
+         target_transform: dict = None,
+         sample_rate: int = 16_000,
+     ):
+         super().__init__()
+         self.sample_rate = sample_rate
+         self.encoder = WavlmAudioEncoderMultiOutput(
+             proj_num_layer=proj_num_layer,
+             proj_ln=proj_ln,
+             proj_act_fn=proj_act_fn,
+             proj_dropout=proj_dropout,
+             nth_layer=nth_layer,
+             use_weighted_layer_sum=use_weighted_layer_sum,
+             precision=precision,
+             normalize_embed=normalize_embed,
+             output_dim=output_dim,
+         )
+         self.target_transform = {
+             axis: Normalize(
+                 mean=target_transform[axis]["mean"],
+                 std=target_transform[axis]["std"],
+             )
+             for axis in target_transform.keys()
+         }
+
+     def _load_base_checkpoint(self, checkpoint_pth: str):
+         with open(checkpoint_pth, "rb") as fin:
+             ckpt = torch.load(fin, map_location="cpu", weights_only=True)
+             state_dict = {
+                 re.sub("^model.", "", k): v for (k, v) in ckpt["state_dict"].items()
+             }
+
+             self.encoder.load_state_dict(state_dict)
+
+     def forward(self, batch, inference_mode: bool = True):
+         if inference_mode:
+             with torch.inference_mode():
+                 result = self.encoder(batch)
+         else:
+             result = self.encoder(batch)
+         return result
+
+     def _process_single_audio(self, wav: torch.Tensor, sample_rate: int):
+         """
+         Process a single audio file to the target sample rate and return a tensor of shape (1, 1, T)
+         """
+         target_sample_rate = self.sample_rate
+         wav = torchaudio.functional.resample(wav, sample_rate, target_sample_rate)
+
+         # convert to mono
+         if wav.shape[0] > 1:
+             wav = wav.mean(dim=0, keepdim=True)
+         return wav, target_sample_rate
+
+     def load_audio(self, path: str, start_time: float = None, end_time: float = None):
+         """
+         Load an audio file from path
+
+         Args:
+             path: str - path to the audio file
+             start_time: float - start time in seconds
+             end_time: float - end time in seconds
+         Returns:
+             wav: torch.Tensor - audio tensor of shape (1, 1, T)
+         """
+         wav, sample_rate = torchaudio.load(path)
+         if start_time is not None and end_time is not None:
+             if start_time and end_time:
+                 wav = wav[
+                     :, int(start_time * sample_rate) : int(end_time * sample_rate)
+                 ]
+             elif start_time:
+                 wav = wav[:, int(start_time * sample_rate) :]
+             elif end_time:
+                 wav = wav[:, : int(end_time * sample_rate)]
+
+         wav, _sr = self._process_single_audio(wav, sample_rate)
+
+         return wav
+
+     def predict_from_files(
+         self, audio_file_list: AudioFileList | AudioFile
+     ) -> List[dict]:
+         """
+         Predict the aesthetic score for a list of audio files
+         """
+         if isinstance(audio_file_list, AudioFile):
+             audio_file_list = AudioFileList(files=[audio_file_list])
+
+         wavs = [
+             self.load_audio(file.path, file.start_time, file.end_time)
+             for file in audio_file_list.files
+         ]
+
+         return self.predict_from_wavs(wavs)
+
+     def predict_from_wavs(self, wavs: List[torch.Tensor] | torch.Tensor):
+         """
+         Predict the aesthetic scores for one or more waveforms
+
+         Args:
+             wavs: List[torch.Tensor] - list of audio tensors of shape (1, 1, T) - must be at the sample rate of the model
+         Returns:
+             preds: List[dict] - list of dictionaries containing the aesthetic scores for each axis
+         """
+
+         if isinstance(wavs, torch.Tensor):
+             wavs = [wavs]
+
+         n_wavs = len(wavs)
+
+         wavs, masks, weights, bids = make_inference_batch(
+             wavs,
+             10,
+             10,
+             sample_rate=self.sample_rate,
+         )
+
+         # stack wavs, masks, weights, bids
+         wavs = torch.stack(wavs)
+         masks = torch.stack(masks)
+         weights = torch.tensor(weights)
+         bids = torch.tensor(bids)
+
+         if not wavs.shape[0] == masks.shape[0] == weights.shape[0] == bids.shape[0]:
+             raise ValueError("Batch size mismatch")
+
+         preds_all = self.forward({"wav": wavs, "mask": masks})
+         all_result = {}
+
+         # predict scores across all axes
+         for axis in self.target_transform.keys():
+             preds = self.target_transform[axis].inverse(preds_all[axis])
+             weighted_preds = []
+             for bii in range(n_wavs):
+                 weights_bii = weights[bids == bii]
+                 weighted_preds.append(
+                     (
+                         (preds[bids == bii] * weights_bii).sum() / weights_bii.sum()
+                     ).item()
+                 )
+             all_result[axis] = weighted_preds
+         # re-arrange result
+         preds = [dict(zip(all_result.keys(), vv)) for vv in zip(*all_result.values())]
+
+         return preds

@@ -9,8 +9,8 @@ import sys
9
  from torch import nn
10
  import torch
11
 
12
- from .utils import create_mlp_block
13
- from .wavlm import WavLM, WavLMConfig
14
 
15
 
16
  DEFAULT_AUDIO_CFG = WavLMConfig(
 
9
  from torch import nn
10
  import torch
11
 
12
+ from audiobox_aesthetics.model.utils import create_mlp_block
13
+ from audiobox_aesthetics.model.wavlm import WavLM, WavLMConfig
14
 
15
 
16
  DEFAULT_AUDIO_CFG = WavLMConfig(
src/audiobox_aesthetics/model/wavlm.py CHANGED
@@ -244,17 +244,17 @@ def quant_noise(module, p, block_size):

      # 2D matrix
      if not is_conv:
-         assert (
-             module.weight.size(1) % block_size == 0
-         ), "Input features must be a multiple of block sizes"
+         assert module.weight.size(1) % block_size == 0, (
+             "Input features must be a multiple of block sizes"
+         )

      # 4D matrix
      else:
          # 1x1 convolutions
          if module.kernel_size == (1, 1):
-             assert (
-                 module.in_channels % block_size == 0
-             ), "Input channels must be a multiple of block sizes"
+             assert module.in_channels % block_size == 0, (
+                 "Input channels must be a multiple of block sizes"
+             )
          # regular convolutions
          else:
              k = module.kernel_size[0] * module.kernel_size[1]
@@ -356,16 +356,16 @@ class MultiheadAttention(nn.Module):
          self.head_dim = embed_dim // num_heads
          self.q_head_dim = self.head_dim
          self.k_head_dim = self.head_dim
-         assert (
-             self.head_dim * num_heads == self.embed_dim
-         ), "embed_dim must be divisible by num_heads"
+         assert self.head_dim * num_heads == self.embed_dim, (
+             "embed_dim must be divisible by num_heads"
+         )
          self.scaling = self.head_dim**-0.5

          self.self_attention = self_attention
          self.encoder_decoder_attention = encoder_decoder_attention

          assert not self.self_attention or self.qkv_same_dim, (
-             "Self-attention requires query, key and " "value to be of the same size"
+             "Self-attention requires query, key and value to be of the same size"
          )

          k_bias = True
@@ -1255,9 +1255,9 @@ class ConvFeatureExtractionModel(nn.Module):
                  nn.init.kaiming_normal_(conv.weight)
                  return conv

-             assert (
-                 is_layer_norm and is_group_norm
-             ) is False, "layer norm and group norm are exclusive"
+             assert (is_layer_norm and is_group_norm) is False, (
+                 "layer norm and group norm are exclusive"
+             )

              if is_layer_norm:
                  return nn.Sequential(

test/test_inference.py ADDED
@@ -0,0 +1,74 @@
+ from audiobox_aesthetics.inference import AudioBoxAesthetics, AudioFileList, AudioFile
+
+ # cached results from running the CLI
+ cli_results = {
+     "sample_audio/libritts_spk-84.wav": {
+         "CE": 6.1027421951293945,
+         "CU": 6.3574299812316895,
+         "PC": 1.7401179075241089,
+         "PQ": 6.733065128326416,
+     },
+ }
+
+
+ def test_inference():
+     audio_path = "sample_audio/libritts_spk-84.wav"
+     audio_file = AudioFile(path=audio_path)
+     model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+     model.eval()
+
+     predictions = model.predict_from_files(audio_file)
+     single_pred = predictions[0]
+
+     print(single_pred)
+
+     assert single_pred["CE"] == cli_results[audio_path]["CE"]
+     assert single_pred["CU"] == cli_results[audio_path]["CU"]
+     assert single_pred["PC"] == cli_results[audio_path]["PC"]
+     assert single_pred["PQ"] == cli_results[audio_path]["PQ"]
+
+
+ def test_inference_load_from_jsonl():
+     audio_file_list = AudioFileList.from_jsonl("sample_audio/test.jsonl")
+     model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+     model.eval()
+
+     predictions = model.predict_from_files(audio_file_list)
+
+     single_pred = predictions[0]
+     assert single_pred["CE"] == cli_results[audio_file_list.files[0].path]["CE"]
+     assert single_pred["CU"] == cli_results[audio_file_list.files[0].path]["CU"]
+     assert single_pred["PC"] == cli_results[audio_file_list.files[0].path]["PC"]
+     assert single_pred["PQ"] == cli_results[audio_file_list.files[0].path]["PQ"]
+
+
+ def test_inference_twice_on_same_audio_yields_same_result():
+     audio_file = AudioFile(path="sample_audio/libritts_spk-84.wav")
+     model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+     model.eval()
+
+     predictions_a = model.predict_from_files(audio_file)
+     predictions_b = model.predict_from_files(audio_file)
+
+     single_pred_a = predictions_a[0]
+     single_pred_b = predictions_b[0]
+
+     assert single_pred_a["CE"] == single_pred_b["CE"]
+     assert single_pred_a["CU"] == single_pred_b["CU"]
+     assert single_pred_a["PC"] == single_pred_b["PC"]
+     assert single_pred_a["PQ"] == single_pred_b["PQ"]
+
+
+ def test_loading_from_wav():
+     audio_path = "sample_audio/libritts_spk-84.wav"
+     model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+     model.eval()
+
+     wav = model.load_audio(audio_path)
+     predictions = model.predict_from_wavs(wav)
+
+     single_pred = predictions[0]
+     assert single_pred["CE"] == cli_results[audio_path]["CE"]
+     assert single_pred["CU"] == cli_results[audio_path]["CU"]
+     assert single_pred["PC"] == cli_results[audio_path]["PC"]
+     assert single_pred["PQ"] == cli_results[audio_path]["PQ"]

uv.lock ADDED
The diff for this file is too large to render. See raw diff