vampnet / scripts /utils /vamp_folder.py
Hugo Flores Garcia
exps
6f6fd13
raw
history blame
3 kB
from pathlib import Path
import argbind
from tqdm import tqdm
import torch
from vampnet.interface import Interface
Interface = argbind.bind(Interface, positional=True)
def baseline(sig, interface):
return sig
def reconstructed(sig, interface):
return interface.to_signal(
interface.encode(sig)
)
def coarse2fine(sig, interface):
z = interface.encode(sig)
z = z[:, :interface.c2f.n_conditioning_codebooks, :]
z = interface.coarse_to_fine(z)
return interface.to_signal(z)
def one_codebook(sig, interface):
z = interface.encode(sig)
mask = torch.zeros_like(z)
mask[:, 1:, :] = 1
zv = interface.coarse_vamp_v2(
sig, ext_mask=mask,
)
zv = interface.coarse_to_fine(zv)
return interface.to_signal(zv)
def four_codebooks_downsampled_4x(sig, interface):
zv = interface.coarse_vamp_v2(
sig, downsample_factor=4
)
zv = interface.coarse_to_fine(zv)
return interface.to_signal(zv)
def two_codebooks_downsampled_4x(sig, interface):
z = interface.encode(sig)
mask = torch.zeros_like(z)
mask[:, 2:, :] = 1
zv = interface.coarse_vamp_v2(
sig, ext_mask=mask, downsample_factor=4
)
zv = interface.coarse_to_fine(zv)
return interface.to_signal(zv)
def four_codebooks_downsampled_8x(sig, interface):
zv = interface.coarse_vamp_v2(
sig, downsample_factor=8
)
zv = interface.coarse_to_fine(zv)
return interface.to_signal(zv)
SAMPLE_CONDS ={
"baseline": baseline,
"reconstructed": reconstructed,
"coarse2fine": coarse2fine,
"one_codebook": one_codebook,
"four_codebooks_downsampled_4x": four_codebooks_downsampled_4x,
"two_codebooks_downsampled_4x": two_codebooks_downsampled_4x,
"four_codebooks_downsampled_8x": four_codebooks_downsampled_8x,
}
@argbind.bind(without_prefix=True)
def main(
sources=[
"/data/spotdl/audio/val", "/data/spotdl/audio/test"
],
output_dir: str = "./samples",
max_excerpts: int = 5000,
):
interface = Interface()
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
from audiotools.data.datasets import AudioLoader, AudioDataset
loader = AudioLoader(sources=sources)
dataset = AudioDataset(loader,
sample_rate=interface.codec.sample_rate,
duration=interface.coarse.chunk_size_s,
n_examples=max_excerpts,
without_replacement=True,
)
for i in tqdm(range(max_excerpts)):
sig = dataset[i]["signal"]
results = {
name: cond(sig, interface)
for name, cond in SAMPLE_CONDS.items()
}
for name, sig in results.items():
output_dir = Path(output_dir) / name
output_dir.mkdir(exist_ok=True, parents=True)
sig.write(output_dir / f"{i}.wav")
if __name__ == "__main__":
args = argbind.parse_args()
with argbind.scope(args):
main()