Commit 99122c4 · Hugo Flores Garcia committed · 1 parent: 5a343f4

basic readme stuff

Files changed:
- README.md +31 -53
- conf/lora/gas-station.yml +10 -0
- demo.py +3 -2
- scripts/exp/train.py +9 -87
- scripts/utils/vamp_folder.py +5 -5
- vampnet/interface.py +13 -5
- vampnet/modules/base.py +8 -119
- vampnet/modules/layers.py +14 -0
- vampnet/signal.py +5 -0
- vampnet/util.py +3 -34
README.md
CHANGED
@@ -1,80 +1,58 @@

Removed (old README; the old title, several code blocks, and a few lines were elided in the page capture and are left blank):

#

This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec.

##

### Setting everything up

```bash
```

Once run, follow the instructions it prints out to create your environment file, which will be at `env/env.sh`.

Note that if this is a new machine, and the data is not downloaded somewhere on it already, it will ask you for a directory to download the data to.

For Github setup, if you don't have a .netrc token, create one by going to your Github profile -> Developer settings -> Personal access tokens -> Generate new token. Copy the token and [keep it secret, keep it safe](https://www.youtube.com/watch?v=iThtELZvfPs).

When complete, run:

```bash
```

```bash
docker compose run dev
```

To tear down your development environment, just do

```bash
docker compose down
```

`stage` creates a directory with a copy of all of the Git-tracked files in the root repository. `stage` launches a shell into said directory, so all commands are run on the copy of the original repository code. This is useful for rewinding to an old experiment and resuming it, for example. Even if the repository code changes, the snapshot in the experiment directory is unchanged from the original run, so it can be re-used.

Then, the experiment can be run via:

```bash
scripts/exp/train.py \
    --args.load=conf/args.yml \
```

#### Cleaning up after a run

Sometimes DDP runs fail to clear themselves out of the machine. To fix this, run

```bash
```
Added (new README):

# VampNet

This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec.

# Setting up

## Install LAC

Install AudioTools:

```bash
git clone https://github.com/hugofloresgarcia/audiotools.git
pip install -e ./audiotools
```

Install the LAC library:

```bash
git clone https://github.com/hugofloresgarcia/lac.git
pip install -e ./lac
```

Install VampNet:

```bash
git clone https://github.com/hugofloresgarcia/vampnet2.git
pip install -e ./vampnet2
```

## A note on Argbind

This repository relies on [argbind](https://github.com/pseeth/argbind) to manage CLIs and config files. Config files are stored in the `conf/` folder.

# Usage

## Staging a Run

Staging a run makes a copy of all the git-tracked files in the codebase and saves them to a folder for reproducibility. You can then run the training script from the staged folder. The staging tooling itself is coming soon; a sketch of the idea follows.
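
Until then, a minimal, runnable sketch of what staging could look like, assuming a git checkout; the `stage` helper and its destination folder here are hypothetical, not the repo's actual tooling:

```python
import shutil
import subprocess
from pathlib import Path

def stage(dst: str = "runs/my-experiment"):
    """Copy every git-tracked file into a snapshot folder."""
    out = Path(dst)
    files = subprocess.check_output(["git", "ls-files"], encoding="utf-8").splitlines()
    for f in files:
        target = out / f
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(f, target)  # preserve timestamps/permissions in the snapshot

if __name__ == "__main__":
    stage()
```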

## Training a model

```bash
python scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/checkpoints
```

## Fine-tuning

To fine-tune a model, see the configuration files under `conf/lora/`. You just need to provide a list of audio files or folders to fine-tune on, then launch the training job as usual.

```bash
python scripts/exp/train.py --args.load conf/lora/birds.yml --save_path /path/to/checkpoints
```

## Launching the Gradio Interface

```bash
python demo.py --args.load conf/interface/spotdl.yml --Interface.device cuda
```
conf/lora/gas-station.yml
ADDED
@@ -0,0 +1,10 @@
+$include:
+  - conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+  - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
+
+val/AudioLoader.sources:
+  - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
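
This new config is a concrete instance of the fine-tuning recipe described in the README above: it includes the shared LoRA settings, enables `fine_tune`, and points both the train and validation `AudioLoader.sources` at the audio to imitate. Following the pattern the README shows for `conf/lora/birds.yml`, it would presumably be launched with `python scripts/exp/train.py --args.load conf/lora/gas-station.yml --save_path /path/to/checkpoints`.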
demo.py
CHANGED

@@ -48,6 +48,7 @@ def load_audio(file):
     sig.write(out_dir / "input.wav")
     return sig.path_to_file
 
+
 def load_random_audio():
     index = np.random.randint(0, len(dataset))
     sig = dataset[index]["signal"]

@@ -68,7 +69,7 @@ def ez_vamp(
     sig = at.AudioSignal(input_audio)
 
     print(f"running standard vampnet with {num_vamps} vamps")
-    zv = interface.
+    zv = interface.coarse_vamp(
         sig,
         sampling_steps=num_steps,
         temperature=(init_temp, final_temp),

@@ -140,7 +141,7 @@ def vamp(
 
     if mode == "standard":
         print(f"running standard vampnet with {num_vamps} vamps")
-        zv, mask_z = interface.
+        zv, mask_z = interface.coarse_vamp(
             sig,
             sampling_steps=num_steps,
             temperature=(init_temp, final_temp),
scripts/exp/train.py
CHANGED

@@ -115,6 +115,10 @@ def load(
     }
     if (Path(kwargs["folder"]) / "vampnet").exists():
        model, v_extra = VampNet.load_from_folder(**kwargs)
+    else:
+        raise ValueError(
+            f"Could not find a VampNet checkpoint in {kwargs['folder']}"
+        )
 
     codec = LAC.load(args["codec_ckpt"], map_location="cpu")
     codec.eval()

@@ -149,25 +153,6 @@ def load(
     }
 
 
-def get_gpu_memory_map():
-    """Get the current gpu usage.
-
-    Returns
-    -------
-    usage: dict
-        Keys are device ids as integers.
-        Values are memory usage as integers in MB.
-    """
-    result = subprocess.check_output(
-        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
-        encoding="utf-8",
-    )
-    # Convert lines into a dictionary
-    gpu_memory = [int(x) for x in result.strip().split("\n")]
-    gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
-    gpu_memory_map = {f"gpu/{k}": v / 1024 for k, v in gpu_memory_map.items()}
-    return gpu_memory_map
-
 
 def num_params_hook(o, p):
     return o + f" {p/1e6:<.3f}M params."

@@ -189,7 +174,6 @@ def accuracy(
     target: torch.Tensor,
     top_k: int = 1,
     ignore_index: Optional[int] = None,
-    **kwargs,
 ) -> torch.Tensor:
     # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
     preds = rearrange(preds, "b p s -> (b s) p")

@@ -214,30 +198,6 @@ def accuracy(
 
     return accuracy
 
-def sample_prefix_suffix_amt(
-    z,
-    n_batch,
-    prefix_amt,
-    suffix_amt,
-    prefix_dropout,
-    suffix_dropout,
-    rng
-):
-    """
-    Sample the number of prefix and suffix tokens to drop.
-    """
-    if prefix_amt > 0.0:
-        prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
-        n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
-    else:
-        n_prefix = None
-    if suffix_amt > 0.0:
-        suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
-        n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
-    else:
-        n_suffix = None
-    return n_prefix, n_suffix
-
 
 @argbind.bind(without_prefix=True)
 def train(

@@ -256,10 +216,6 @@ def train(
     num_workers: int = 10,
     detect_anomaly: bool = False,
     grad_clip_val: float = 5.0,
-    prefix_amt: float = 0.0,
-    suffix_amt: float = 0.0,
-    prefix_dropout: float = 0.1,
-    suffix_dropout: float = 0.1,
     fine_tune: bool = False,
     quiet: bool = False,
 ):

@@ -342,16 +298,12 @@ def train(
         target=r_unmasked_target,
         ignore_index=IGNORE_INDEX,
         top_k=topk,
-        task="multiclass",
-        num_classes=vn.vocab_size,
     )
     output[f"{tag}/masked"] = accuracy(
         preds=r_z_hat,
         target=r_masked_target,
         ignore_index=IGNORE_INDEX,
         top_k=topk,
-        task="multiclass",
-        num_classes=vn.vocab_size,
     )
 
     def train_loop(self, engine, batch):

@@ -370,15 +322,7 @@ def train(
         n_batch = z.shape[0]
         r = rng.draw(n_batch)[:, 0].to(accel.device)
 
-        n_prefix, n_suffix = sample_prefix_suffix_amt(
-            n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
-            prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
-            rng=rng
-        )
-
-        z_mask, mask = vn.add_noise(
-            z, r, n_prefix=n_prefix, n_suffix=n_suffix
-        )
+        z_mask, mask = vn.add_noise(z, r)
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         dtype = torch.bfloat16 if accel.amp else None

@@ -454,13 +398,7 @@ def train(
         n_batch = z.shape[0]
         r = rng.draw(n_batch)[:, 0].to(accel.device)
 
-        n_prefix, n_suffix = sample_prefix_suffix_amt(
-            n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
-            prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
-            rng=rng
-        )
-
-        z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
+        z_mask, mask = vn.add_noise(z, r)
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         z_hat = model(z_mask_latent, r)

@@ -574,17 +512,8 @@ def train(
         )
 
     def save_imputation(self, z: torch.Tensor):
-        _prefix_amt = prefix_amt
-        _suffix_amt = suffix_amt
-
-        if _prefix_amt == 0:
-            _prefix_amt = 0.25
-        if _suffix_amt == 0:
-            _suffix_amt = 0.25
-
-        n_prefix = int(z.shape[-1] * _prefix_amt)
-        n_suffix = int(z.shape[-1] * _suffix_amt)
+        n_prefix = int(z.shape[-1] * 0.25)
+        n_suffix = int(z.shape[-1] * 0.25)
         downsample_factor = None
 
         vn = accel.unwrap(model)

@@ -647,13 +576,7 @@ def train(
 
         n_batch = z.shape[0]
 
-        n_prefix, n_suffix = sample_prefix_suffix_amt(
-            n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
-            prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
-            rng=rng
-        )
-
-        z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
+        z_mask, mask = vn.add_noise(z, r)
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         z_hat = model(z_mask_latent, r)

@@ -664,7 +587,6 @@ def train(
     z_pred = vn.embedding.unflatten(z_pred, n_codebooks=vn.n_predict_codebooks)
     z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
 
-    print("z_mask", z_mask.shape)
     generated = vn.to_signal(z_pred, codec)
     reconstructed = vn.to_signal(z, codec)
     masked = vn.to_signal(z_mask.squeeze(1), codec)
scripts/utils/vamp_folder.py
CHANGED

@@ -56,7 +56,7 @@ class CoarseCond:
 
     def __call__(self, sig, interface):
         n_conditioning_codebooks = interface.coarse.n_codebooks - self.num_codebooks
-        zv = interface.
+        zv = interface.coarse_vamp(sig,
             n_conditioning_codebooks=n_conditioning_codebooks,
             downsample_factor=self.downsample_factor,
         )

@@ -113,7 +113,7 @@ def mask_ratio_1_step(ratio=1.0):
     r = interface.coarse.invgamma(ratio).to(interface.device)
     intensity = 1-r
 
-    zv = interface.
+    zv = interface.coarse_vamp(
         sig,
         sample='argmax',
         sampling_steps=1,

@@ -125,7 +125,7 @@ def mask_ratio_1_step(ratio=1.0):
 
 def num_sampling_steps(num_steps=1):
     def wrapper(sig, interface):
-        zv = interface.
+        zv = interface.coarse_vamp(
             sig,
             downsample_factor=16,
             sampling_steps=num_steps,

@@ -143,7 +143,7 @@ def beat_mask(ctx_time):
         after_beat_s=ctx_time,
         invert=True
     )
-    zv = interface.
+    zv = interface.coarse_vamp(
         sig,
         ext_mask=beat_mask,
     )

@@ -154,7 +154,7 @@ def beat_mask(ctx_time):
 
 def inpaint(ctx_time):
     def wrapper(sig, interface):
-        zv = interface.
+        zv = interface.coarse_vamp(
             sig,
             prefix_dur_s=ctx_time,
             suffix_dur_s=ctx_time,
vampnet/interface.py
CHANGED

@@ -20,6 +20,14 @@ def signal_concat(
     return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
 
 
+class SignalPrompt:
+
+    def __init__(self, signal: AudioSignal):
+        self.sig = signal
+
+
+
+
 class Interface(torch.nn.Module):
     def __init__(
         self,

@@ -28,7 +36,7 @@ class Interface(torch.nn.Module):
         codec_ckpt: str = None,
         wavebeat_ckpt: str = None,
         device: str = "cpu",
-        coarse_chunk_size_s: int =
+        coarse_chunk_size_s: int = 10,
         coarse2fine_chunk_size_s: int = 3,
     ):
         super().__init__()

@@ -141,7 +149,7 @@ class Interface(torch.nn.Module):
         """make a beat synced mask. that is, make a mask that
         places 1s at and around the beat, and 0s everywhere else.
         """
-        assert
+        assert self.beat_tracker is not None, "No beat tracker loaded"
 
         # get the beat times
         beats, downbeats = self.beat_tracker.extract_beats(signal)

@@ -242,7 +250,7 @@ class Interface(torch.nn.Module):
         return fine_z[:, :, :length].clone()
 
 
-    def 
+    def coarse_vamp(
         self,
         signal,
         prefix_dur_s: float = 0.0,

@@ -471,7 +479,7 @@ class Interface(torch.nn.Module):
         else:
             ext_mask = None
 
-        out_z = self.
+        out_z = self.coarse_vamp(
             sig,
             num_vamps=1,
             swap_prefix_suffix=False,

@@ -520,7 +528,7 @@ class Interface(torch.nn.Module):
         range_fn = range if not verbose else tqdm.trange
         for i in range_fn(num_loops):
             is_flipped = i % 2 == 0
-            vamped = self.
+            vamped = self.coarse_vamp(
                 signal,
                 prefix_dur_s=prefix_dur_s,
                 suffix_dur_s=suffix_dur_s,
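
To make the rename concrete, here is a hypothetical sketch of calling the method through the interface. Only keyword arguments visible in this diff are used; the checkpoint path is a placeholder, and the constructor arguments for the coarse and coarse2fine models (elided in the hunk above) would also be required in practice:

```python
# hypothetical usage sketch: the path is a placeholder, and the real
# Interface constructor also needs the (elided) model checkpoint arguments
import audiotools as at
from vampnet.interface import Interface

interface = Interface(
    codec_ckpt="path/to/codec.pth",  # placeholder
    device="cuda",
)

sig = at.AudioSignal("input.wav")

# resynthesize ("vamp") the coarse token stream of the signal
zv = interface.coarse_vamp(
    sig,
    prefix_dur_s=1.0,        # seconds of context kept at the start
    suffix_dur_s=1.0,        # seconds of context kept at the end
    sampling_steps=36,
    temperature=(0.8, 1.0),  # (initial, final) schedule, as in demo.py
)
```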
vampnet/modules/base.py
CHANGED

@@ -10,6 +10,8 @@ import torch.nn.functional as F
 from einops import rearrange
 from tqdm import tqdm
 
+from ..util import scalar_to_batch_tensor
+
 
 def log(t, eps=1e-20):
     return torch.log(t + eps)

@@ -24,9 +26,6 @@ def gumbel_sample(t, temperature=1.0, dim=-1):
     return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
 
 
-def scalar_to_batch_tensor(x, batch_size):
-    return torch.tensor(x).repeat(batch_size)
-
 class VampBase(at.ml.BaseModel):
     def forward(self, x: torch.Tensor, r: torch.Tensor):
         raise NotImplementedError

@@ -150,6 +149,8 @@ class VampBase(at.ml.BaseModel):
             z_hat = z_hat * mask + truth * (1 - mask)
 
             z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
+        else:
+            raise ValueError(f"invalid noise mode for adding truth to logits {self.noise_mode}")
 
         return z_hat
 

@@ -186,6 +187,9 @@ class VampBase(at.ml.BaseModel):
 
     @torch.no_grad()
     def to_signal(self, z, codec):
+        """
+        convert a sequence of latents to a signal.
+        """
         if z.ndim == 2:
             z = self.embedding.unflatten(z)
         assert z.ndim == 3

@@ -207,122 +211,7 @@ class VampBase(at.ml.BaseModel):
         return signal
 
     @torch.no_grad()
-    def sample(self, **kwargs):
-        if self.noise_mode == "mask":
-            return self.maskgit_sample(**kwargs)
-        else:
-            return self.paella_sample(**kwargs)
-
-    def paella_sample(
-        self,
-        codec,
-        time_steps: int = 400,
-        sampling_steps: int = 36,
-        start_tokens: Optional[torch.Tensor] = None,
-        mask: Optional[torch.Tensor] = None,
-        temperature: Union[float, Tuple[float, float]] = 0.8,
-        top_k: int = None,
-        sample: str = "gumbel",
-        renoise_mode: str = "start",
-        renoise_steps=None,
-        typical_filtering=True,
-        typical_mass=0.2,
-        typical_min_tokens=1,
-        return_signal=True,
-    ):
-
-        r = torch.linspace(0, 1, sampling_steps + 1)[:-1][:, None].to(self.device)
-        if renoise_steps == None:
-            renoise_steps = sampling_steps - 1
-
-        if isinstance(temperature, float):
-            temperature = torch.tensor(temperature).repeat(sampling_steps)
-        elif isinstance(temperature, tuple):
-            assert len(temperature) == 2
-            l, h = temperature
-            temperature = torch.linspace(l, h, sampling_steps)
-        else:
-            raise TypeError(f"invalid type for temperature")
-
-        if self.n_conditioning_codebooks > 0:
-            assert (
-                start_tokens is not None
-            ), "must provide start_tokens if n_conditioning_codebooks > 0"
-
-        if start_tokens is None:
-            if self.noise_mode == "noise":
-                z = torch.randint(
-                    0, self.vocab_size, size=(1, self.n_codebooks, time_steps)
-                ).to(self.device)
-            elif self.noise_mode == "mask":
-                z = torch.full((1, self.n_codebooks, time_steps), self.mask_token)
-        else:
-            z = start_tokens
-            assert (
-                z.ndim == 3
-            ), f"start_tokens must be shape (batch, n_codebooks, seq_len), got {z.shape}"
-            assert z.shape[0] == 1, f"batch size must be 1"
-
-        if mask is None:
-            mask = torch.ones(z.shape[0], z.shape[-1]).to(self.device).int()
-            mask = mask[:, None, :]
-            mask = mask.repeat(1, z.shape[1], 1)
-
-            mask[:, : self.n_conditioning_codebooks, :] = 0.0
-
-        z_true = z.clone()
-
-        z, mask = self.add_noise(z, r=r[0], random_x=None, mask=mask)
-        z_init = z.clone()
-        for i, tmpt in enumerate(temperature):
-            if renoise_mode == "prev":
-                z_prev = z.clone()
-
-            latents = self.embedding.from_codes(z, codec)
-            logits = self.forward(latents, r[i])
-
-            # for mask mode
-            logits = self.add_truth_to_logits(z_true, logits, mask)
-
-            # Apply topk sampling
-            logits = logits.permute(0, 2, 1)
-
-            z = self.sample_from_logits(
-                logits,
-                top_k=top_k,
-                temperature=tmpt,
-                sample=sample,
-                typical_filtering=typical_filtering,
-                typical_mass=typical_mass,
-                typical_min_tokens=typical_min_tokens,
-            )
-
-            # add back in conditioning codebooks
-            z = self.embedding.unflatten(z, n_codebooks=self.n_predict_codebooks)
-            z = torch.cat(
-                [z_init[:, : self.n_conditioning_codebooks, :], z], dim=1
-            ).int()
-
-            if i < renoise_steps:
-                if renoise_mode == "prev":
-                    z, _ = self.add_noise(z, r[i + 1], random_x=z_prev)
-                elif renoise_mode == "start":
-                    z, _ = self.add_noise(z, r[i + 1], random_x=z_init)
-                elif renoise_mode == "rand":
-                    z, _ = self.add_noise(z, r[i + 1])
-                else:
-                    raise ValueError(f"Invalid renoise_mode: {renoise_mode}")
-
-            if mask is not None:
-                z = start_tokens * (1 - mask) + z * mask
-
-        if return_signal:
-            return self.to_signal(z, codec)
-        else:
-            return z
-
-    def maskgit_sample(
+    def sample(
         self,
         codec,
         time_steps: int = 300,
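
The helper imported above now lives in `vampnet/util.py` (see that file's diff below). A quick sketch of what it does, assuming the package is importable as `vampnet`:

```python
import torch
from vampnet.util import scalar_to_batch_tensor

# broadcast a single mask ratio to every item in a batch of 4
r = scalar_to_batch_tensor(0.8, 4)
print(r)        # tensor([0.8000, 0.8000, 0.8000, 0.8000])
print(r.shape)  # torch.Size([4])
```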
vampnet/modules/layers.py
CHANGED

@@ -132,6 +132,11 @@ class CodebookEmbedding(nn.Module):
         self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
 
     def from_codes(self, codes: torch.Tensor, codec):
+        """
+        get a sequence of continuous embeddings from a sequence of discrete codes.
+        unlike its counterpart in the original VQ-VAE, this function adds any special
+        tokens necessary for the language model, like <MASK>.
+        """
         n_codebooks = codes.shape[1]
         latent = []
         for i in range(n_codebooks):

@@ -151,14 +156,23 @@ class CodebookEmbedding(nn.Module):
         return latent
 
     def forward(self, latents: torch.Tensor):
+        """
+        project a sequence of latents to a sequence of embeddings
+        """
         x = self.out_proj(latents)
         return x
 
     def flatten(self, tokens: torch.Tensor, n_codebooks: int = None):
+        """
+        flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
+        """
         n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
         return rearrange(tokens, "b c t -> b (t c)", c=n_c)
 
     def unflatten(self, flat_tokens: torch.Tensor, n_codebooks: int = None):
+        """
+        unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
+        """
         nb, nt = flat_tokens.shape
 
         n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
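
The new flatten/unflatten docstrings are easy to verify with a toy tensor. A minimal, runnable sketch of the round trip, using the flatten pattern from the code and its natural einops inverse (an assumption, since the unflatten body is truncated above):

```python
import torch
from einops import rearrange

# toy codes: batch=1, n_codebooks=4, time=3
tokens = torch.arange(12).reshape(1, 4, 3)

flat = rearrange(tokens, "b c t -> b (t c)", c=4)  # (1, 12), time-major interleave
back = rearrange(flat, "b (t c) -> b c t", c=4)    # (1, 4, 3)

assert torch.equal(tokens, back)  # the round trip is lossless
```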
vampnet/signal.py
ADDED
@@ -0,0 +1,5 @@
+import torch
+from typing import Optional, Tuple
+
+from .util import scalar_to_batch_tensor
+
vampnet/util.py
CHANGED

@@ -1,40 +1,9 @@
 import tqdm
-# import pathos
 
-def process_map(fn, *iterables, **tqdm_kwargs):
-    """
-    Equivalent of `list(map(fn, *iterables))`
-    driven by `concurrent.futures.ProcessPoolExecutor`.
 
-    Parameters
-    ----------
-    tqdm_class : optional
-        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
-    max_workers : int, optional
-        Maximum number of workers to spawn; passed to
-        `concurrent.futures.ProcessPoolExecutor.__init__`.
-        [default: min(32, cpu_count() + 4)].
-    chunksize : int, optional
-        Size of chunks sent to worker processes; passed to
-        `concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
-    lock_name : str, optional
-        Member of `tqdm_class.get_lock()` to use [default: mp_lock].
-    """
-    from concurrent.futures import ProcessPoolExecutor
-    if iterables and "chunksize" not in tqdm_kwargs:
-        # default `chunksize=1` has poor performance for large iterables
-        # (most time spent dispatching items to workers).
-        longest_iterable_len = max(map(length_hint, iterables))
-        if longest_iterable_len > 1000:
-            from warnings import warn
-            warn("Iterable length %d > 1000 but `chunksize` is not set."
-                 " This may seriously degrade multiprocess performance."
-                 " Set `chunksize=1` or more." % longest_iterable_len,
-                 TqdmWarning, stacklevel=2)
-    if "lock_name" not in tqdm_kwargs:
-        tqdm_kwargs = tqdm_kwargs.copy()
-        tqdm_kwargs["lock_name"] = "mp_lock"
-    return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)
 
+import torch
 
+def scalar_to_batch_tensor(x, batch_size):
+    return torch.tensor(x).repeat(batch_size)
 
 
 def parallelize(