Spaces:
Runtime error
Runtime error
File size: 2,998 Bytes
6f6fd13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
from pathlib import Path
import argbind
from tqdm import tqdm
import torch
from vampnet.interface import Interface
# Rebind the imported ``Interface`` name to the result of
# ``argbind.bind(Interface, positional=True)``; ``main`` below instantiates this
# bound version with a bare ``Interface()`` call.
# NOTE(review): presumably this lets the constructor arguments be supplied from
# the argbind-parsed args/scope -- confirm against the argbind documentation.
Interface = argbind.bind(Interface, positional=True)
def baseline(sig, interface):
    """Return ``sig`` exactly as given.

    ``interface`` is accepted but never used; it is only there so this function
    can be called with the same ``(sig, interface)`` arguments as the other
    condition functions.
    """
    return sig
def reconstructed(sig, interface):
    """Round-trip ``sig`` through the interface.

    Passes ``sig`` to ``interface.encode`` and hands the result straight to
    ``interface.to_signal``, returning whatever that produces.
    """
    encoded = interface.encode(sig)
    return interface.to_signal(encoded)
def coarse2fine(sig, interface):
    """Encode ``sig``, keep only the conditioning slice, and run coarse-to-fine.

    The encoded result is truncated along axis 1 to its first
    ``interface.c2f.n_conditioning_codebooks`` entries, passed through
    ``interface.coarse_to_fine`` and converted with ``interface.to_signal``.
    """
    codes = interface.encode(sig)
    n_conditioning = interface.c2f.n_conditioning_codebooks
    # Indexing with three subscripts: the encoded result is treated as 3-D.
    conditioning_codes = codes[:, :n_conditioning, :]
    fine_codes = interface.coarse_to_fine(conditioning_codes)
    return interface.to_signal(fine_codes)
def one_codebook(sig, interface):
    """Run ``coarse_vamp_v2`` with an external mask that is 0 only at index 0 of axis 1.

    ``sig`` is encoded solely to obtain a tensor of the right shape/dtype for
    the mask. The mask is 1 for indices ``1:`` along axis 1 and 0 at index 0.
    The ``coarse_vamp_v2`` output is then passed through
    ``interface.coarse_to_fine`` and ``interface.to_signal``.
    """
    codes = interface.encode(sig)

    # NOTE(review): presumably axis 1 is the codebook axis and a mask value of
    # 1 marks positions to be regenerated -- confirm against vampnet.
    ext_mask = torch.zeros_like(codes)
    ext_mask[:, 1:, :] = 1

    generated = interface.coarse_vamp_v2(sig, ext_mask=ext_mask)
    return interface.to_signal(interface.coarse_to_fine(generated))
def four_codebooks_downsampled_4x(sig, interface):
    """Run ``coarse_vamp_v2`` on ``sig`` with ``downsample_factor=4`` and no external mask.

    The result is passed through ``interface.coarse_to_fine`` and then
    ``interface.to_signal``, whose return value is returned.
    """
    generated = interface.coarse_vamp_v2(sig, downsample_factor=4)
    fine_codes = interface.coarse_to_fine(generated)
    return interface.to_signal(fine_codes)
def two_codebooks_downsampled_4x(sig, interface):
    """Run ``coarse_vamp_v2`` with ``downsample_factor=4`` and a mask that is 0 at indices 0-1 of axis 1.

    ``sig`` is encoded solely to obtain a tensor of the right shape/dtype for
    the mask. The mask is 1 for indices ``2:`` along axis 1 and 0 elsewhere.
    The ``coarse_vamp_v2`` output is then passed through
    ``interface.coarse_to_fine`` and ``interface.to_signal``.
    """
    codes = interface.encode(sig)

    # NOTE(review): presumably axis 1 is the codebook axis and a mask value of
    # 1 marks positions to be regenerated -- confirm against vampnet.
    ext_mask = torch.zeros_like(codes)
    ext_mask[:, 2:, :] = 1

    generated = interface.coarse_vamp_v2(
        sig,
        ext_mask=ext_mask,
        downsample_factor=4,
    )
    return interface.to_signal(interface.coarse_to_fine(generated))
def four_codebooks_downsampled_8x(sig, interface):
    """Run ``coarse_vamp_v2`` on ``sig`` with ``downsample_factor=8`` and no external mask.

    The result is passed through ``interface.coarse_to_fine`` and then
    ``interface.to_signal``, whose return value is returned.
    """
    generated = interface.coarse_vamp_v2(sig, downsample_factor=8)
    fine_codes = interface.coarse_to_fine(generated)
    return interface.to_signal(fine_codes)
# Registry of sampling conditions: maps a condition name to the callable that
# produces that condition's output. ``main`` invokes every entry as
# ``cond(sig, interface)`` and uses the key as an output sub-directory name.
SAMPLE_CONDS = {
    # No vamping: passthrough and encode/decode round trip.
    "baseline": baseline,
    "reconstructed": reconstructed,
    # Coarse-to-fine only.
    "coarse2fine": coarse2fine,
    # ``coarse_vamp_v2`` variants (external mask and/or downsample factor).
    "one_codebook": one_codebook,
    "four_codebooks_downsampled_4x": four_codebooks_downsampled_4x,
    "two_codebooks_downsampled_4x": two_codebooks_downsampled_4x,
    "four_codebooks_downsampled_8x": four_codebooks_downsampled_8x,
}
@argbind.bind(without_prefix=True)
def main(
    # This list default is never mutated here; it is kept as a literal so the
    # bound signature and its defaults stay exactly as before.
    sources=[
        "/data/spotdl/audio/val", "/data/spotdl/audio/test"
    ],
    output_dir: str = "./samples",
    max_excerpts: int = 5000,
):
    """Render every sampling condition for a set of audio excerpts.

    For each excerpt index ``i`` in ``range(max_excerpts)`` the excerpt is run
    through every callable in ``SAMPLE_CONDS`` and each result is written to
    ``<output_dir>/<condition name>/<i>.wav``.

    Parameters
    ----------
    sources
        Audio source locations handed to ``AudioLoader(sources=...)``.
    output_dir : str
        Root directory for the rendered files; created if missing.
    max_excerpts : int
        Number of excerpts to draw from the dataset and render.
    """
    interface = Interface()

    output_root = Path(output_dir)
    output_root.mkdir(exist_ok=True, parents=True)

    from audiotools.data.datasets import AudioLoader, AudioDataset

    loader = AudioLoader(sources=sources)
    dataset = AudioDataset(
        loader,
        sample_rate=interface.codec.sample_rate,
        duration=interface.coarse.chunk_size_s,
        n_examples=max_excerpts,
        without_replacement=True,
    )

    # Bug fix: the previous code rebound ``output_dir`` to ``output_dir / name``
    # inside the write loop, so every write nested one level deeper
    # (samples/baseline/reconstructed/...). Each condition now gets one fixed
    # directory derived from the unchanged root, created once up front.
    cond_dirs = {name: output_root / name for name in SAMPLE_CONDS}
    for cond_dir in cond_dirs.values():
        cond_dir.mkdir(exist_ok=True, parents=True)

    for i in tqdm(range(max_excerpts)):
        sig = dataset[i]["signal"]

        # Compute every condition for this excerpt before writing any of them.
        results = {
            name: cond(sig, interface)
            for name, cond in SAMPLE_CONDS.items()
        }

        # ``rendered`` is a distinct name so the input ``sig`` is not shadowed.
        for name, rendered in results.items():
            rendered.write(cond_dirs[name] / f"{i}.wav")
if __name__ == "__main__":
    # Script entry point: parse the arguments with argbind, then call ``main``
    # inside the matching ``argbind.scope`` context.
    parsed_args = argbind.parse_args()
    with argbind.scope(parsed_args):
        main()
|