diff --git "a/notebooks/diffusion_test.ipynb" "b/notebooks/diffusion_test.ipynb" deleted file mode 100644--- "a/notebooks/diffusion_test.ipynb" +++ /dev/null @@ -1,876 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "4c52cc1c-91f1-4b79-924b-041d2929ef7b", - "metadata": {}, - "outputs": [], - "source": [ - "from audio_diffusion_pytorch import AudioDiffusionModel, Sampler, Schedule, VSampler, LinearSchedule, AudioDiffusionAE\n", - "import torch\n", - "from torch import Tensor, nn, optim\n", - "from IPython.display import Audio\n", - "import pytorch_lightning as pl\n", - "from torch.utils.data import random_split, DataLoader, Dataset\n", - "\n", - "from einops import rearrange\n", - "from ema_pytorch import EMA\n", - "from pytorch_lightning import Callback, Trainer\n", - "from typing import Any, Callable, Dict, List, Optional, Sequence, Union\n", - "from pytorch_lightning.loggers import WandbLogger\n", - "import wandb\n", - "import torchaudio\n", - "import librosa\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "a005011f-3019-4d34-bdf2-9a00e5480282", - "metadata": {}, - "outputs": [], - "source": [ - "# device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "6349ed8e-f418-436f-860e-62a51e48f79a", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmattricesound\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] - }, - { - "data": { - "text/html": [ - "Tracking run with wandb version 0.13.7" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Run data is saved locally in ./wandb/run-20230107_213018-192gzo2n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Syncing run laced-bush-17 to Weights & Biases (docs)
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "wandb_logger = WandbLogger(project=\"RemFX\", save_dir=\"./\")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "1b689f18-375f-4b40-9ddc-a4ced6a5e5e4", - "metadata": {}, - "outputs": [], - "source": [ - "#AudioDiffusionModel\n", - "#AudioDiffusionAE\n", - "model = AudioDiffusionModel(in_channels=1, \n", - " patch_size=1,\n", - " multipliers=[1, 2, 4, 4, 4, 4, 4],\n", - " factors=[2, 2, 2, 2, 2, 2],\n", - " num_blocks=[2, 2, 2, 2, 2, 2],\n", - " attentions=[0, 0, 0, 0, 0, 0]\n", - " )\n", - "\n", - "\n", - "# model = model.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "950711d4-9e8a-4af1-8d56-204e4ce0a19b", - "metadata": {}, - "outputs": [], - "source": [ - "class Model(pl.LightningModule):\n", - " def __init__(\n", - " self,\n", - " lr: float,\n", - " lr_eps: float,\n", - " lr_beta1: float,\n", - " lr_beta2: float,\n", - " lr_weight_decay: float,\n", - " ema_beta: float,\n", - " ema_power: float,\n", - " model: nn.Module,\n", - " ):\n", - " super().__init__()\n", - " self.lr = lr\n", - " self.lr_eps = lr_eps\n", - " self.lr_beta1 = lr_beta1\n", - " self.lr_beta2 = lr_beta2\n", - " self.lr_weight_decay = lr_weight_decay\n", - " self.model = model\n", - " self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)\n", - "\n", - " @property\n", - " def device(self):\n", - " return next(self.model.parameters()).device\n", - "\n", - " def configure_optimizers(self):\n", - " optimizer = torch.optim.AdamW(\n", - " list(self.parameters()),\n", - " lr=self.lr,\n", - " betas=(self.lr_beta1, self.lr_beta2),\n", - " eps=self.lr_eps,\n", - " weight_decay=self.lr_weight_decay,\n", - " )\n", - " return optimizer\n", - "\n", - " def training_step(self, batch, batch_idx):\n", - " waveforms = batch\n", - " loss = self.model(waveforms)\n", - " self.log(\"train_loss\", loss)\n", - " self.model_ema.update()\n", - " self.log(\"ema_decay\", self.model_ema.get_current_decay())\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " waveforms = batch\n", - " loss = self.model_ema(waveforms)\n", - " self.log(\"valid_loss\", loss)\n", - " return loss" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ce9b20b-d163-425a-a92d-8ddb1a92b905", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "cfa42700-f190-485d-84b9-d9203f8275d7", - "metadata": {}, - "outputs": [], - "source": [ - "params = {\n", - " \"lr\": 1e-4,\n", - " \"lr_beta1\": 0.95,\n", - " \"lr_beta2\": 0.999,\n", - " \"lr_eps\": 1e-6,\n", - " \"lr_weight_decay\": 1e-3,\n", - " \"ema_beta\": 0.995,\n", - " \"ema_power\": 0.7,\n", - " \"model\": model \n", - "}\n", - "diffModel = Model(**params)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "aa4029a4-efd8-4922-a863-cf7677e86c05", - "metadata": {}, - "outputs": [], - "source": [ - "fs = 22050\n", - "t = 2 ** 18 / fs # 12 seconds\n", - "\n", - "class SinDataset(Dataset):\n", - " def __init__(self, num):\n", - " self.n = num\n", - " self.samples = torch.arange(t * fs) / fs\n", - " def __len__(self):\n", - " return self.n\n", - " def __getitem__(self, i): \n", - " f = 6000 * torch.rand(1) + 300\n", - " signal = torch.sin(2 * torch.pi * (f*2) * self.samples).unsqueeze(0)\n", - " return signal" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "ae57ad99-fdaf-4720-91b0-ce9338e6a811", - 
"metadata": {}, - "outputs": [], - "source": [ - "data = DataLoader(SinDataset(1000), batch_size=2)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "7b131b37-485f-4d4f-8616-6e7afe25beb9", - "metadata": {}, - "outputs": [], - "source": [ - "val_data = DataLoader(SinDataset(1000), batch_size=2)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "4d98c1a0-1763-4d0b-be1d-e84ace68bebb", - "metadata": {}, - "outputs": [], - "source": [ - "dataiter = iter(data)\n", - "x = next(dataiter)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "c3259082-20d5-415c-8a88-3b97af6615ee", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 1, 262144])" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "d1ec36ea-0f9c-49f6-8f24-a479084ea230", - "metadata": {}, - "outputs": [], - "source": [ - "class SampleLogger(Callback):\n", - " def __init__(\n", - " self,\n", - " num_items: int,\n", - " channels: int,\n", - " sampling_rate: int,\n", - " length: int,\n", - " sampling_steps: List[int],\n", - " diffusion_schedule: Schedule,\n", - " diffusion_sampler: Sampler,\n", - " use_ema_model: bool,\n", - " ) -> None:\n", - " self.num_items = num_items\n", - " self.channels = channels\n", - " self.sampling_rate = sampling_rate\n", - " self.length = length\n", - " self.sampling_steps = sampling_steps\n", - " self.diffusion_schedule = diffusion_schedule\n", - " self.diffusion_sampler = diffusion_sampler\n", - " self.use_ema_model = use_ema_model\n", - "\n", - " self.log_next = False\n", - "\n", - " def on_validation_epoch_start(self, trainer, pl_module):\n", - " self.log_next = True\n", - "\n", - " def on_validation_batch_start(\n", - " self, trainer, pl_module, batch, batch_idx, dataloader_idx\n", - " ):\n", - " if self.log_next:\n", - " self.log_sample(trainer, pl_module, batch)\n", - " self.log_next = False\n", - "\n", - " @torch.no_grad()\n", - " def log_sample(self, trainer, pl_module, batch):\n", - " is_train = pl_module.training\n", - " if is_train:\n", - " pl_module.eval()\n", - "\n", - " wandb_logger = get_wandb_logger(trainer).experiment\n", - "\n", - " diffusion_model = pl_module.model\n", - " if self.use_ema_model:\n", - " diffusion_model = pl_module.model_ema.ema_model\n", - " # Get start diffusion noise\n", - " noise = torch.randn(\n", - " (self.num_items, self.channels, self.length), device=pl_module.device\n", - " )\n", - "\n", - " for steps in self.sampling_steps:\n", - " samples = diffusion_model.sample(\n", - " noise=noise,\n", - " sampler=self.diffusion_sampler,\n", - " sigma_schedule=self.diffusion_schedule,\n", - " num_steps=steps,\n", - " )\n", - " log_wandb_audio_batch(\n", - " logger=wandb_logger,\n", - " id=\"sample\",\n", - " samples=samples,\n", - " sampling_rate=self.sampling_rate,\n", - " caption=f\"Sampled in {steps} steps\",\n", - " )\n", - " # log_wandb_audio_spectrogram(\n", - " # logger=wandb_logger,\n", - " # id=\"sample\",\n", - " # samples=samples,\n", - " # sampling_rate=self.sampling_rate,\n", - " # caption=f\"Sampled in {steps} steps\",\n", - " # )\n", - "\n", - " if is_train:\n", - " pl_module.train()\n", - "\n", - "def get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]:\n", - " \"\"\"Safely get Weights&Biases logger from Trainer.\"\"\"\n", - "\n", - " if isinstance(trainer.logger, WandbLogger):\n", - " return trainer.logger\n", - "\n", - " if 
# %%
class SampleLogger(Callback):
    def __init__(
        self,
        num_items: int,
        channels: int,
        sampling_rate: int,
        length: int,
        sampling_steps: List[int],
        diffusion_schedule: Schedule,
        diffusion_sampler: Sampler,
        use_ema_model: bool,
    ) -> None:
        self.num_items = num_items
        self.channels = channels
        self.sampling_rate = sampling_rate
        self.length = length
        self.sampling_steps = sampling_steps
        self.diffusion_schedule = diffusion_schedule
        self.diffusion_sampler = diffusion_sampler
        self.use_ema_model = use_ema_model
        self.log_next = False

    def on_validation_epoch_start(self, trainer, pl_module):
        self.log_next = True

    def on_validation_batch_start(
        self, trainer, pl_module, batch, batch_idx, dataloader_idx
    ):
        if self.log_next:
            self.log_sample(trainer, pl_module, batch)
            self.log_next = False

    @torch.no_grad()
    def log_sample(self, trainer, pl_module, batch):
        is_train = pl_module.training
        if is_train:
            pl_module.eval()

        wandb_logger = get_wandb_logger(trainer).experiment

        diffusion_model = pl_module.model
        if self.use_ema_model:
            diffusion_model = pl_module.model_ema.ema_model

        # Start every step-count sweep from the same noise so samples are comparable
        noise = torch.randn(
            (self.num_items, self.channels, self.length), device=pl_module.device
        )

        for steps in self.sampling_steps:
            samples = diffusion_model.sample(
                noise=noise,
                sampler=self.diffusion_sampler,
                sigma_schedule=self.diffusion_schedule,
                num_steps=steps,
            )
            log_wandb_audio_batch(
                logger=wandb_logger,
                id="sample",
                samples=samples,
                sampling_rate=self.sampling_rate,
                caption=f"Sampled in {steps} steps",
            )
            # log_wandb_audio_spectrogram(...)  # optional mel-spectrogram logging

        if is_train:
            pl_module.train()


def get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]:
    """Safely get the Weights & Biases logger from the Trainer."""
    if isinstance(trainer.logger, WandbLogger):
        return trainer.logger

    if isinstance(trainer.logger, LoggerCollection):
        for logger in trainer.logger:
            if isinstance(logger, WandbLogger):
                return logger

    print("WandbLogger not found.")
    return None


def log_wandb_audio_batch(logger, id: str, samples: Tensor, sampling_rate: int, caption: str = ""):
    # `logger` is the wandb run (WandbLogger.experiment), not the WandbLogger itself
    num_items = samples.shape[0]
    samples = rearrange(samples, "b c t -> b t c").detach().cpu().numpy()
    logger.log(
        {
            f"sample_{idx}_{id}": wandb.Audio(
                samples[idx],
                caption=caption,
                sample_rate=sampling_rate,
            )
            for idx in range(num_items)
        }
    )


def log_wandb_audio_spectrogram(logger, id: str, samples: Tensor, sampling_rate: int, caption: str = ""):
    num_items = samples.shape[0]
    samples = samples.detach().cpu()
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sampling_rate,
        n_fft=1024,
        hop_length=512,
        n_mels=80,
        center=True,
        norm="slaney",
    )

    def get_spectrogram_image(x):
        spectrogram = transform(x[0])
        image = librosa.power_to_db(spectrogram)
        trace = [go.Heatmap(z=image, colorscale="viridis")]
        layout = go.Layout(
            yaxis=dict(title="Mel Bin (Log Frequency)"),
            xaxis=dict(title="Frame"),
            title_text=caption,
            title_font_size=10,
        )
        return go.Figure(data=trace, layout=layout)

    logger.log(
        {
            f"mel_spectrogram_{idx}_{id}": get_spectrogram_image(samples[idx])
            for idx in range(num_items)
        }
    )

# %%
vsampler = VSampler()
linear_schedule = LinearSchedule()
samples_config = {
    "num_items": 3,
    "channels": 1,
    "sampling_rate": fs,
    "sampling_steps": [3, 5, 10, 25, 50, 100],
    "use_ema_model": True,
    "diffusion_sampler": vsampler,
    "length": 262144,
    "diffusion_schedule": linear_schedule,
}
s = SampleLogger(**samples_config)

# %%
trainer = pl.Trainer(
    limit_train_batches=100,
    max_epochs=100,
    accelerator="gpu",
    devices=[1],
    callbacks=[s],
    logger=wandb_logger,
)
# [output] GPU available: True (cuda), used: True; TPU/IPU/HPU: not used
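Before handing the callback to the Trainer, the same sampling path can be exercised once by hand; a minimal sketch (untrained weights, so the result is just shaped noise):

# %%
# One-off check of the sampling path the callback will use
# (same sampler/schedule as samples_config; tiny step count for speed).
with torch.no_grad():
    _noise = torch.randn(1, 1, 262144)
    _preview = model.sample(
        noise=_noise,
        sampler=vsampler,
        sigma_schedule=linear_schedule,
        num_steps=5,
    )
print(_preview.shape)  # torch.Size([1, 1, 262144])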
"--------------------------------------------------\n", - "74.3 M Trainable params\n", - "74.3 M Non-trainable params\n", - "148 M Total params\n", - "594.631 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", - " rank_zero_warn(\n", - "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", - " rank_zero_warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5327d73bb6114877adb4e9f991058eea", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, 
- "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c6e6b42717824054b576e47f92878ef5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "trainer.fit(model=diffModel, train_dataloaders=data, val_dataloaders=val_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1f64d981-c9dc-4afa-b783-d017f99633da", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "53bba197-83eb-40a2-b748-a4c25e628356", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "49db25f0-8bda-4693-9872-cbf24c40b575", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "29ed502f-2daf-4210-81ff-a90ade519086", - "metadata": {}, - "outputs": [], - "source": [ - "# Old code below" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "bd8a1cb4-42b5-43bc-9a12-f594ce069b33", - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'device' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn [14], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m signal2 \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpi \u001b[38;5;241m*\u001b[39m (f\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m*\u001b[39m samples)\n\u001b[1;32m 11\u001b[0m stacked_signal \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack((signal1, signal2))\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m stacked_signal \u001b[38;5;241m=\u001b[39m stacked_signal\u001b[38;5;241m.\u001b[39mto(\u001b[43mdevice\u001b[49m)\n\u001b[1;32m 13\u001b[0m loss \u001b[38;5;241m=\u001b[39m model(stacked_signal)\n\u001b[1;32m 14\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward() \n", - "\u001b[0;31mNameError\u001b[0m: name 'device' is not defined" - ] - } - ], - "source": [ - "fs = 22050\n", - "t = 2 ** 18 / 22050\n", - "samples = torch.arange(t * fs) / fs\n", - "\n", - "for i in range(300, 8000):\n", 
- " f = i\n", - " # Create 2 sine waves (one at f=step, other is octave up) \n", - " # There is aliasing at higher freq, but since it is sinusoids, that doesn't matter too much\n", - " signal1 = torch.sin(2 * torch.pi * f * samples)\n", - " signal2 = torch.sin(2 * torch.pi * (f*2) * samples)\n", - " stacked_signal = torch.stack((signal1, signal2)).unsqueeze(1)\n", - " stacked_signal = stacked_signal.to(device)\n", - " loss = model(stacked_signal)\n", - " loss.backward() \n", - " if i % 10 == 0:\n", - " print(\"Step\", i)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "71d17c51-842c-40a1-81a1-a53bf358bc8a", - "metadata": {}, - "outputs": [], - "source": [ - "# Sample 2 sources given start noise\n", - "noise = torch.randn(2, 1, 2 ** 18)\n", - "noise = noise.to(device)\n", - "sampled = model.sample(\n", - " noise=noise,\n", - " num_steps=10 # Suggested range: 2-50\n", - ") # [2, 1, 2 ** 18]" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "59d71efa-05ac-4545-84da-8c09c033dfd7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "z = sampled[1]\n", - "Audio(z.cpu(), rate=22050)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "81eddd71-bba7-4c62-8d50-900b295bb2f8", - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'z' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn [12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mz\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", - "\u001b[0;31mNameError\u001b[0m: name 'z' is not defined" - ] - } - ], - "source": [ - "z.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8a3f582f-a956-4326-872b-416cc13b77ee", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}