port to support hf
Files changed:
- .github/workflows/pre-commit.yaml +14 -0
- .github/workflows/run-pytest.yaml +24 -0
- .gitignore +5 -1
- .pre-commit-config.yaml +22 -0
- examples/example.jsonl +1 -0
- examples/predict_from_jsonl.py +9 -0
- examples/predict_single_file.py +8 -0
- pyproject.toml +25 -23
- sample_audio/libritts_spk-3170.wav +0 -0
- sample_audio/libritts_spk-84.wav +0 -0
- sample_audio/test.jsonl +1 -0
- src/audiobox_aesthetics/cli.py +1 -1
- src/audiobox_aesthetics/demo.py +86 -0
- src/audiobox_aesthetics/export_model_to_hf.py +77 -0
- src/audiobox_aesthetics/infer.py +6 -4
- src/audiobox_aesthetics/inference.py +217 -0
- src/audiobox_aesthetics/model/aes_wavlm.py +2 -2
- src/audiobox_aesthetics/model/wavlm.py +13 -13
- test/test_inference.py +74 -0
- uv.lock +0 -0
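
The net effect of the port is a Hugging Face-style interface: the model is wrapped in a PyTorchModelHubMixin-based class (src/audiobox_aesthetics/inference.py) so it can be saved with save_pretrained and reloaded with from_pretrained. A minimal usage sketch, assuming the checkpoint has already been exported locally under the name "audiobox-aesthetics" as export_model_to_hf.py does:

from audiobox_aesthetics.inference import AudioBoxAesthetics, AudioFile

# Load the exported checkpoint (a local directory or a Hub repo id).
model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
model.eval()

# Score one file; the result is a list with one dict of CE/CU/PC/PQ scores.
predictions = model.predict_from_files(AudioFile(path="sample_audio/libritts_spk-84.wav"))
print(predictions[0])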
.github/workflows/pre-commit.yaml
ADDED
@@ -0,0 +1,14 @@
+name: Pre-commit
+on:
+  pull_request:
+  push:
+    branches: [main]
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - uses: pre-commit/[email protected]
.github/workflows/run-pytest.yaml
ADDED
@@ -0,0 +1,24 @@
+name: PyTest
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+
+jobs:
+  run-pytest:
+    name: python
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+
+      - name: Install the project
+        run: uv sync --all-extras --dev
+
+      - name: Run tests
+        run: uv run pytest test/
.gitignore
CHANGED
@@ -1,3 +1,7 @@
+.gradio/
+*.pt
+*.pth
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -161,4 +165,4 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 # .idea/
 .vscode/
-.ruff_cache/
+.ruff_cache/
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,22 @@
+repos:
+  - repo: https://github.com/google/yamlfmt
+    rev: v0.16.0
+    hooks:
+      - id: yamlfmt
+  - repo: https://github.com/gitleaks/gitleaks
+    rev: v8.23.3
+    hooks:
+      - id: gitleaks
+  - repo: https://github.com/astral-sh/uv-pre-commit
+    # uv version.
+    rev: 0.5.30
+    hooks:
+      - id: uv-lock
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.9.6
+    hooks:
+      - id: ruff
+        types_or: [python, pyi]
+        args: [--fix]
+      - id: ruff-format
+        types_or: [python, pyi]
examples/example.jsonl
ADDED
@@ -0,0 +1 @@
+{"path": "sample_audio/libritts_spk-84.wav"}
examples/predict_from_jsonl.py
ADDED
@@ -0,0 +1,9 @@
+from audiobox_aesthetics.inference import AudioBoxAesthetics, AudioFileList
+
+model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+model.eval()
+
+
+audio_file_list = AudioFileList.from_jsonl("examples/example.jsonl")
+predictions = model.predict_from_files(audio_file_list)
+print(predictions)
examples/predict_single_file.py
ADDED
@@ -0,0 +1,8 @@
+from audiobox_aesthetics.inference import AudioBoxAesthetics
+
+model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+model.eval()
+
+wav = model.load_audio("sample_audio/libritts_spk-84.wav")
+predictions = model.predict_from_wavs(wav)
+print(predictions)
pyproject.toml
CHANGED
@@ -6,23 +6,21 @@ build-backend = "setuptools.build_meta"
 name = "audiobox_aesthetics"
 version = "0.0.1"
 authors = [
-    {name="Andros Tjandra", email="[email protected]"},
-    {name="Yi-Chiao Wu"},
-    {name="Baishan Guo"},
-    {name="John Hoffman"},
-    {name="Brian Ellis"},
-    {name="Apoorv Vyas"},
-    {name="Bowen Shi"},
-    {name="Sanyuan Chen"},
-    {name="Matt Le"},
-    {name="Nick Zacharov"},
-    {name="Carleigh Wood"},
-    {name="Ann Lee"},
-    {name="Wei-ning Hsu"}
-]
-maintainers = [
-    {name="Andros Tjandra", email="[email protected]"}
+    { name = "Andros Tjandra", email = "[email protected]" },
+    { name = "Yi-Chiao Wu" },
+    { name = "Baishan Guo" },
+    { name = "John Hoffman" },
+    { name = "Brian Ellis" },
+    { name = "Apoorv Vyas" },
+    { name = "Bowen Shi" },
+    { name = "Sanyuan Chen" },
+    { name = "Matt Le" },
+    { name = "Nick Zacharov" },
+    { name = "Carleigh Wood" },
+    { name = "Ann Lee" },
+    { name = "Wei-ning Hsu" },
 ]
+maintainers = [{ name = "Andros Tjandra", email = "[email protected]" }]
 description = "Unified automatic quality assessment for speech, music, and sound."
 requires-python = ">=3.9"
 classifiers = [
@@ -30,14 +28,17 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 readme = "README.md"
-license = {file = "LICENSE"}
+license = { file = "LICENSE" }
 
 dependencies = [
-
-
-
-
-
+    "numpy",
+    "torch>=2.2.0",
+    "torchaudio",
+    "tqdm",
+    "submitit",
+    "huggingface-hub>=0.28.1",
+    "pydantic>=2.10.6",
+    "safetensors>=0.5.2",
 ]
 
 [project.scripts]
@@ -47,4 +48,5 @@ audio-aes = "audiobox_aesthetics.cli:app"
 Homepage = "https://github.com/facebookresearch/audiobox-aesthetics"
 Issues = "https://github.com/facebookresearch/audiobox-aesthetics/issues"
 
-
+[dependency-groups]
+dev = ["gradio>=4.44.1", "ipykernel>=6.29.5", "pytest>=8.3.4"]
sample_audio/libritts_spk-3170.wav
ADDED
Binary file (292 kB).
sample_audio/libritts_spk-84.wav
ADDED
Binary file (287 kB).
sample_audio/test.jsonl
ADDED
@@ -0,0 +1 @@
+{"path": "sample_audio/libritts_spk-84.wav"}
src/audiobox_aesthetics/cli.py
CHANGED
@@ -14,7 +14,7 @@ import requests
 
 import submitit
 from tqdm import tqdm
-from .infer import load_dataset, main_predict
+from audiobox_aesthetics.infer import load_dataset, main_predict
 
 logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
 
src/audiobox_aesthetics/demo.py
ADDED
@@ -0,0 +1,86 @@
+import gradio as gr
+from audiobox_aesthetics.inference import (
+    AudioBoxAesthetics,
+    AudioFile,
+    AXIS_NAME_LOOKUP,
+)
+
+# Load the pre-trained model
+model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+model.eval()
+
+
+def predict_aesthetics(audio_file):
+    # Create an AudioFile instance
+    audio_file_instance = AudioFile(path=audio_file)
+
+    # Predict using the model
+    predictions = model.predict_from_files(audio_file_instance)
+
+    single_prediction = predictions[0]
+
+    data_view = [
+        [AXIS_NAME_LOOKUP[key], value] for key, value in single_prediction.items()
+    ]
+
+    return single_prediction, data_view
+
+
+def create_demo():
+    # Create a Gradio Blocks interface
+    with gr.Blocks() as demo:
+        gr.Markdown("# AudioBox Aesthetics Prediction")
+        with gr.Group():
+            gr.Markdown("""Upload an audio file to predict its aesthetic scores.
+
+            This demo uses the AudioBox Aesthetics model to predict aesthetic scores for audio files along 4 axes:
+            - Content Enjoyment (CE)
+            - Content Usefulness (CU)
+            - Production Complexity (PC)
+            - Production Quality (PQ)
+
+            Scores range from 0 to 10.
+
+            For more details, see the [paper](https://arxiv.org/abs/2502.05139) or [code](https://github.com/facebookresearch/audiobox-aesthetics/tree/main).
+            """)
+
+        with gr.Row():
+            with gr.Group():
+                with gr.Column():
+                    audio_input = gr.Audio(
+                        sources="upload", type="filepath", label="Upload Audio"
+                    )
+                    submit_button = gr.Button("Predict", variant="primary")
+            with gr.Group():
+                with gr.Column():
+                    output_data = gr.Dataframe(
+                        headers=["Axes name", "Score"],
+                        datatype=["str", "number"],
+                        label="Aesthetic Scores",
+                    )
+                    output_text = gr.Textbox(label="Raw prediction", interactive=False)
+
+        submit_button.click(
+            predict_aesthetics,
+            inputs=audio_input,
+            outputs=[output_text, output_data],
+        )
+
+        # Add examples
+        gr.Examples(
+            examples=[
+                "sample_audio/libritts_spk-84.wav",
+                "sample_audio/libritts_spk-3170.wav",
+            ],
+            inputs=audio_input,
+            outputs=[output_text, output_data],
+            fn=predict_aesthetics,
+            cache_examples=True,
+        )
+
+    return demo
+
+
+if __name__ == "__main__":
+    demo = create_demo()
+    demo.launch()
src/audiobox_aesthetics/export_model_to_hf.py
ADDED
@@ -0,0 +1,77 @@
+import requests
+import os
+import argparse
+import torch
+
+from audiobox_aesthetics.inference import AudioBoxAesthetics
+
+if __name__ == "__main__":
+    # Set up argument parser
+    parser = argparse.ArgumentParser(
+        description="Download and test AudioBox Aesthetics model"
+    )
+    parser.add_argument(
+        "--checkpoint-url",
+        default="https://dl.fbaipublicfiles.com/audiobox-aesthetics/checkpoint.pt",
+        help="URL for the base checkpoint",
+    )
+    parser.add_argument(
+        "--model-name",
+        default="audiobox-aesthetics",
+        help="Name to save/load the pretrained model",
+    )
+    parser.add_argument(
+        "--push-to-hub",
+        action="store_true",
+        help="Push the model to the Hugging Face Hub",
+    )
+    args = parser.parse_args()
+
+    checkpoint_local_path = "base_checkpoint.pth"
+
+    if not os.path.exists(checkpoint_local_path):
+        print("Downloading base checkpoint")
+        response = requests.get(args.checkpoint_url)
+        with open(checkpoint_local_path, "wb") as f:
+            f.write(response.content)
+
+    # get model config from the base checkpoint
+    checkpoint = torch.load(
+        checkpoint_local_path, map_location="cpu", weights_only=True
+    )
+    model_cfg = checkpoint["model_cfg"]
+
+    # extract normalization params from the base checkpoint
+    target_transform = checkpoint["target_transform"]
+
+    target_transform = {
+        axis: {
+            "mean": checkpoint["target_transform"][axis]["mean"],
+            "std": checkpoint["target_transform"][axis]["std"],
+        }
+        for axis in target_transform.keys()
+    }
+
+    model = AudioBoxAesthetics(
+        sample_rate=16_000, target_transform=target_transform, **model_cfg
+    )
+
+    model._load_base_checkpoint(checkpoint_local_path)
+    print("✅ Loaded model from base checkpoint")
+
+    model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)
+    print(f"✅ Saved model to {args.model_name}")
+    if args.push_to_hub:
+        model.push_to_hub(args.model_name)
+        print(f"✅ Pushed model to Hub under {args.model_name}")
+
+    # test load from pretrained
+    model = AudioBoxAesthetics.from_pretrained(args.model_name)
+    model.eval()
+    print(f"✅ Loaded model from pretrained {args.model_name}")
+
+    # test inference
+    wav = model.load_audio("sample_audio/libritts_spk-84.wav")
+    predictions = model.predict_from_wavs(wav)
+    print(predictions)
+    print("✅ Inference test passed")
src/audiobox_aesthetics/infer.py
CHANGED
@@ -14,7 +14,7 @@ import torch
 import torchaudio
 import torch.nn.functional as F
 
-from .model.aes_wavlm import Normalize, WavlmAudioEncoderMultiOutput
+from audiobox_aesthetics.model.aes_wavlm import Normalize, WavlmAudioEncoderMultiOutput
 
 Batch = Dict[str, Any]
 
@@ -113,6 +113,8 @@ class AesWavlmPredictorMultiOutput:
             "bf16": torch.bfloat16,
         }.get(self.precision)
 
+        print("using precision", self.precision)
+
         self.target_transform = {
             axis: Normalize(
                 mean=ckpt["target_transform"][axis]["mean"],
@@ -205,8 +207,8 @@ def main_predict(input_file, ckpt, batch_size=10):
     for ii in tqdm(range(0, len(metadata), batch_size)):
         output = predictor.forward(metadata[ii : ii + batch_size])
         outputs.extend(output)
-    assert len(outputs) == len(
-        metadata
-    )
+    assert len(outputs) == len(metadata), (
+        f"Output {len(outputs)} != input {len(metadata)} length"
+    )
 
     return outputs
src/audiobox_aesthetics/inference.py
ADDED
@@ -0,0 +1,217 @@
+import re
+
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+
+from audiobox_aesthetics.model.aes_wavlm import Normalize, WavlmAudioEncoderMultiOutput
+from audiobox_aesthetics.infer import make_inference_batch
+
+from pydantic import BaseModel
+import torchaudio
+
+from pydantic import BaseModel, Field
+from typing import Optional, List
+import json
+
+AXIS_NAME_LOOKUP = {
+    "CE": "Content Enjoyment",
+    "CU": "Content Usefulness",
+    "PC": "Production Complexity",
+    "PQ": "Production Quality",
+}
+
+
+class AudioFile(BaseModel):
+    """
+    Audio file to be processed
+    """
+
+    path: str
+    start_time: Optional[float] = Field(None, description="Start time in seconds")
+    end_time: Optional[float] = Field(None, description="End time in seconds")
+
+
+class AudioFileList(BaseModel):
+    """
+    List of audio files to be processed
+    """
+
+    files: List[AudioFile]
+
+    @classmethod
+    def from_jsonl(cls, filename: str) -> "AudioFileList":
+        audio_files = []
+        with open(filename, "r") as f:
+            for line in f:
+                data = json.loads(line.strip())
+                audio_file = AudioFile(**data)
+                audio_files.append(audio_file)
+        return cls(files=audio_files)
+
+
+# model
+
+
+class AudioBoxAesthetics(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="audiobox-aesthetics",
+    repo_url="https://github.com/facebookresearch/audiobox-aesthetics",
+):
+    def __init__(
+        self,
+        proj_num_layer: int = 1,
+        proj_ln: bool = False,
+        proj_act_fn: str = "gelu",
+        proj_dropout: float = 0.0,
+        nth_layer: int = 13,
+        use_weighted_layer_sum: bool = True,
+        precision: str = "32",
+        normalize_embed: bool = True,
+        output_dim: int = 1,
+        target_transform: dict = None,
+        sample_rate: int = 16_000,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.encoder = WavlmAudioEncoderMultiOutput(
+            proj_num_layer=proj_num_layer,
+            proj_ln=proj_ln,
+            proj_act_fn=proj_act_fn,
+            proj_dropout=proj_dropout,
+            nth_layer=nth_layer,
+            use_weighted_layer_sum=use_weighted_layer_sum,
+            precision=precision,
+            normalize_embed=normalize_embed,
+            output_dim=output_dim,
+        )
+        self.target_transform = {
+            axis: Normalize(
+                mean=target_transform[axis]["mean"],
+                std=target_transform[axis]["std"],
+            )
+            for axis in target_transform.keys()
+        }
+
+    def _load_base_checkpoint(self, checkpoint_pth: str):
+        with open(checkpoint_pth, "rb") as fin:
+            ckpt = torch.load(fin, map_location="cpu", weights_only=True)
+        state_dict = {
+            re.sub("^model.", "", k): v for (k, v) in ckpt["state_dict"].items()
+        }
+
+        self.encoder.load_state_dict(state_dict)
+
+    def forward(self, batch, inference_mode: bool = True):
+        if inference_mode:
+            with torch.inference_mode():
+                result = self.encoder(batch)
+        else:
+            result = self.encoder(batch)
+        return result
+
+    def _process_single_audio(self, wav: torch.Tensor, sample_rate: int):
+        """
+        Process a single audio file to the target sample rate and return a tensor of shape (1, 1, T)
+        """
+        target_sample_rate = self.sample_rate
+        wav = torchaudio.functional.resample(wav, sample_rate, target_sample_rate)
+
+        # convert to mono
+        if wav.shape[0] > 1:
+            wav = wav.mean(dim=0, keepdim=True)
+        return wav, target_sample_rate
+
+    def load_audio(self, path: str, start_time: float = None, end_time: float = None):
+        """
+        Load an audio file from path
+
+        Args:
+            path: str - path to the audio file
+            start_time: float - start time in seconds
+            end_time: float - end time in seconds
+        Returns:
+            wav: torch.Tensor - audio tensor of shape (1, 1, T)
+        """
+        wav, sample_rate = torchaudio.load(path)
+        if start_time is not None and end_time is not None:
+            if start_time and end_time:
+                wav = wav[
+                    :, int(start_time * sample_rate) : int(end_time * sample_rate)
+                ]
+            elif start_time:
+                wav = wav[:, int(start_time * sample_rate) :]
+            elif end_time:
+                wav = wav[:, : int(end_time * sample_rate)]
+
+        wav, _sr = self._process_single_audio(wav, sample_rate)
+
+        return wav
+
+    def predict_from_files(
+        self, audio_file_list: AudioFileList | AudioFile
+    ) -> List[dict]:
+        """
+        Predict the aesthetic score for a list of audio files
+        """
+        if isinstance(audio_file_list, AudioFile):
+            audio_file_list = AudioFileList(files=[audio_file_list])
+
+        wavs = [
+            self.load_audio(file.path, file.start_time, file.end_time)
+            for file in audio_file_list.files
+        ]
+
+        return self.predict_from_wavs(wavs)
+
+    def predict_from_wavs(self, wavs: List[torch.Tensor] | torch.Tensor):
+        """
+        Predict the aesthetic score for a single audio file
+
+        Args:
+            wavs: List[torch.Tensor] - list of audio tensors of shape (1, 1, T) - must be at the sample rate of the model
+        Returns:
+            preds: List[dict] - list of dictionaries containing the aesthetic scores for each axis
+        """
+
+        if isinstance(wavs, torch.Tensor):
+            wavs = [wavs]
+
+        n_wavs = len(wavs)
+
+        wavs, masks, weights, bids = make_inference_batch(
+            wavs,
+            10,
+            10,
+            sample_rate=self.sample_rate,
+        )
+
+        # stack wavs, masks, weights, bids
+        wavs = torch.stack(wavs)
+        masks = torch.stack(masks)
+        weights = torch.tensor(weights)
+        bids = torch.tensor(bids)
+
+        if not wavs.shape[0] == masks.shape[0] == weights.shape[0] == bids.shape[0]:
+            raise ValueError("Batch size mismatch")
+
+        preds_all = self.forward({"wav": wavs, "mask": masks})
+        all_result = {}
+
+        # predict scores across all axes
+        for axis in self.target_transform.keys():
+            preds = self.target_transform[axis].inverse(preds_all[axis])
+            weighted_preds = []
+            for bii in range(n_wavs):
+                weights_bii = weights[bids == bii]
+                weighted_preds.append(
+                    (
+                        (preds[bids == bii] * weights_bii).sum() / weights_bii.sum()
+                    ).item()
+                )
+            all_result[axis] = weighted_preds
+        # re-arrange results
+        preds = [dict(zip(all_result.keys(), vv)) for vv in zip(*all_result.values())]
+
+        return preds
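
Besides whole files, AudioFile carries optional start_time and end_time fields, so a segment of a longer recording can be scored directly. A small sketch of that use, assuming the same locally exported "audiobox-aesthetics" checkpoint and a clip longer than the chosen window:

from audiobox_aesthetics.inference import AudioBoxAesthetics, AudioFile

model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
model.eval()

# Trim to the 0.5 s - 2.5 s window before resampling and scoring (times in seconds).
segment = AudioFile(path="sample_audio/libritts_spk-84.wav", start_time=0.5, end_time=2.5)
predictions = model.predict_from_files(segment)
print(predictions[0])  # {"CE": ..., "CU": ..., "PC": ..., "PQ": ...}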
src/audiobox_aesthetics/model/aes_wavlm.py
CHANGED
@@ -9,8 +9,8 @@ import sys
 from torch import nn
 import torch
 
-from .utils import create_mlp_block
-from .wavlm import WavLM, WavLMConfig
+from audiobox_aesthetics.model.utils import create_mlp_block
+from audiobox_aesthetics.model.wavlm import WavLM, WavLMConfig
 
 
 DEFAULT_AUDIO_CFG = WavLMConfig(
src/audiobox_aesthetics/model/wavlm.py
CHANGED
@@ -244,17 +244,17 @@ def quant_noise(module, p, block_size):
 
     # 2D matrix
     if not is_conv:
-        assert (
-
-        )
+        assert module.weight.size(1) % block_size == 0, (
+            "Input features must be a multiple of block sizes"
+        )
 
     # 4D matrix
     else:
         # 1x1 convolutions
         if module.kernel_size == (1, 1):
-            assert (
-
-            )
+            assert module.in_channels % block_size == 0, (
+                "Input channels must be a multiple of block sizes"
+            )
         # regular convolutions
         else:
             k = module.kernel_size[0] * module.kernel_size[1]
@@ -356,16 +356,16 @@ class MultiheadAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         self.q_head_dim = self.head_dim
        self.k_head_dim = self.head_dim
-        assert (
-
-        )
+        assert self.head_dim * num_heads == self.embed_dim, (
+            "embed_dim must be divisible by num_heads"
+        )
         self.scaling = self.head_dim**-0.5
 
         self.self_attention = self_attention
         self.encoder_decoder_attention = encoder_decoder_attention
 
         assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and
+            "Self-attention requires query, key and value to be of the same size"
         )
 
         k_bias = True
@@ -1255,9 +1255,9 @@ class ConvFeatureExtractionModel(nn.Module):
                 nn.init.kaiming_normal_(conv.weight)
                 return conv
 
-            assert (
-
-            )
+            assert (is_layer_norm and is_group_norm) is False, (
+                "layer norm and group norm are exclusive"
+            )
 
             if is_layer_norm:
                 return nn.Sequential(
test/test_inference.py
ADDED
@@ -0,0 +1,74 @@
+from audiobox_aesthetics.inference import AudioBoxAesthetics, AudioFileList, AudioFile
+
+# cached results from running the CLI
+cli_results = {
+    "sample_audio/libritts_spk-84.wav": {
+        "CE": 6.1027421951293945,
+        "CU": 6.3574299812316895,
+        "PC": 1.7401179075241089,
+        "PQ": 6.733065128326416,
+    },
+}
+
+
+def test_inference():
+    audio_path = "sample_audio/libritts_spk-84.wav"
+    audio_file = AudioFile(path=audio_path)
+    model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+    model.eval()
+
+    predictions = model.predict_from_files(audio_file)
+    single_pred = predictions[0]
+
+    print(single_pred)
+
+    assert single_pred["CE"] == cli_results[audio_path]["CE"]
+    assert single_pred["CU"] == cli_results[audio_path]["CU"]
+    assert single_pred["PC"] == cli_results[audio_path]["PC"]
+    assert single_pred["PQ"] == cli_results[audio_path]["PQ"]
+
+
+def test_inference_load_from_jsonl():
+    audio_file_list = AudioFileList.from_jsonl("sample_audio/test.jsonl")
+    model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+    model.eval()
+
+    predictions = model.predict_from_files(audio_file_list)
+
+    single_pred = predictions[0]
+    assert single_pred["CE"] == cli_results[audio_file_list.files[0].path]["CE"]
+    assert single_pred["CU"] == cli_results[audio_file_list.files[0].path]["CU"]
+    assert single_pred["PC"] == cli_results[audio_file_list.files[0].path]["PC"]
+    assert single_pred["PQ"] == cli_results[audio_file_list.files[0].path]["PQ"]
+
+
+def test_inference_twice_on_same_audio_yields_same_result():
+    audio_file = AudioFile(path="sample_audio/libritts_spk-84.wav")
+    model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+    model.eval()
+
+    predictions_a = model.predict_from_files(audio_file)
+    predictions_b = model.predict_from_files(audio_file)
+
+    single_pred_a = predictions_a[0]
+    single_pred_b = predictions_b[0]
+
+    assert single_pred_a["CE"] == single_pred_b["CE"]
+    assert single_pred_a["CU"] == single_pred_b["CU"]
+    assert single_pred_a["PC"] == single_pred_b["PC"]
+    assert single_pred_a["PQ"] == single_pred_b["PQ"]
+
+
+def test_loading_from_wav():
+    audio_path = "sample_audio/libritts_spk-84.wav"
+    model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
+    model.eval()
+
+    wav = model.load_audio(audio_path)
+    predictions = model.predict_from_wavs(wav)
+
+    single_pred = predictions[0]
+    assert single_pred["CE"] == cli_results[audio_path]["CE"]
+    assert single_pred["CU"] == cli_results[audio_path]["CU"]
+    assert single_pred["PC"] == cli_results[audio_path]["PC"]
+    assert single_pred["PQ"] == cli_results[audio_path]["PQ"]
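
The tests above compare against cached CLI floats with exact equality, which can be brittle across PyTorch builds and hardware. A hedged sketch of a tolerance-based variant, assuming it lives in the same test module (so cli_results and the imports above are in scope); pytest.approx is standard pytest and is not part of this commit:

import pytest


def test_inference_close_to_cli():
    # Same setup as test_inference, but allow a small absolute tolerance per axis.
    audio_path = "sample_audio/libritts_spk-84.wav"
    model = AudioBoxAesthetics.from_pretrained("audiobox-aesthetics")
    model.eval()

    single_pred = model.predict_from_files(AudioFile(path=audio_path))[0]
    for axis, expected in cli_results[audio_path].items():
        assert single_pred[axis] == pytest.approx(expected, abs=1e-4)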
uv.lock
ADDED
The diff for this file is too large to render.