diff --git "a/notebooks/Experiments.ipynb" "b/notebooks/Experiments.ipynb" deleted file mode 100644--- "a/notebooks/Experiments.ipynb" +++ /dev/null @@ -1,1568 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 3, - "id": "c2807819", - "metadata": {}, - "outputs": [], - "source": [ - "from audio_diffusion_pytorch import AudioDiffusionModel\n", - "import torch\n", - "from tqdm import tqdm\n", - "from IPython.display import Audio\n", - "from pathlib import Path\n", - "import torchaudio\n", - "import torchaudio.transforms as T\n", - "import pytorch_lightning as pl\n", - "from torch.utils.data import random_split, DataLoader, Dataset\n", - "import torch.nn.functional as F\n", - "import wandb\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "cdd08230-c057-4a6e-83b9-435b2c0fbaaf", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'1.13.0+cu117'" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.__version__" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "469edd04", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmattricesound\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] - }, - { - "data": { - "text/html": [ - "Tracking run with wandb version 0.13.6" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Run data is saved locally in /home/jovyan/RemFx/wandb/run-20221209_160820-9wzgwfl3" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Syncing run fast-snowflake-6 to Weights & Biases (docs)
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "wandb.init(project=\"RemFX\", entity=\"mattricesound\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8d7eacfc", - "metadata": {}, - "outputs": [], - "source": [ - "SAMPLE_RATE = 22050\n", - "LENGTH = 2**17#round(5 * SAMPLE_RATE) 6 seconds" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "d8f78b50-b8f5-4008-b986-fb02590a9cd1", - "metadata": {}, - "outputs": [], - "source": [ - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "cdc0fb64", - "metadata": {}, - "outputs": [], - "source": [ - "class GuitarDataset(Dataset):\n", - " def __init__(self, root, length=LENGTH):\n", - " self.files = list(Path().glob(f\"{root}/**/*.wav\"))\n", - " self.resampler = T.Resample(48000, SAMPLE_RATE)\n", - " \n", - " def __len__(self):\n", - " return len(self.files)\n", - " \n", - " def __getitem__(self, idx):\n", - " x, sr = torchaudio.load(self.files[idx])\n", - "# x = x.view() # Duplicate channel\n", - " resampled_x = self.resampler(x)\n", - " if resampled_x.shape[1] < LENGTH:\n", - " resampled_x = F.pad(resampled_x, (0, LENGTH - resampled_x.shape[1]))\n", - " elif resampled_x.shape[1] > LENGTH:\n", - " resampled_x = resampled_x[:, :LENGTH]\n", - " return resampled_x.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "148c2a96", - "metadata": {}, - "outputs": [], - "source": [ - "g = GuitarDataset(Path(\"Clean\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "670c94a5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 131072])\n" - ] - } - ], - "source": [ - "x = g[10]\n", - "print(x.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e1c83600", - "metadata": {}, - "outputs": [], - "source": [ - "data = DataLoader(GuitarDataset(Path(\"Clean\")), batch_size=32)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "4d46f992", - "metadata": {}, - "outputs": [], - "source": [ - "dataiter = iter(data)\n", - "x = next(dataiter)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "1103e520", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 131072])" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "6b0f1575", - "metadata": {}, - "outputs": [], - "source": [ - "# wandb.log({\"Audio\": wandb.Audio(x[0].view(-1).numpy(), sample_rate=SAMPLE_RATE)})" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "314fd8af-a813-436e-9ca5-29dc3a5ad460", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "eff19abd-304c-449e-9fb5-4e9ce4d4b19c", - "metadata": {}, - "outputs": [], - "source": [ - "model = AudioDiffusionModel(in_channels=1, \n", - " patch_size=1,\n", - " multipliers=[1, 2, 4, 4, 4, 4, 4],\n", - " factors=[2, 2, 2, 2, 2, 2],\n", - " num_blocks=[2, 2, 2, 2, 2, 2],\n", - " attentions=[0, 0, 0, 0, 0, 0]\n", - " )\n", - "model = model.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "75dd6e95-5e31-43f5-a0f8-05c7e13e7a14", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "300\n", - "310\n", - "320\n", - "330\n", - "340\n", - "350\n", - "360\n", - "370\n", - "380\n", - "390\n", - "400\n", - "410\n", - "420\n", - "430\n", - "440\n", - "450\n", - "460\n", - "470\n", - "480\n", - "490\n", - "500\n", - "510\n", - "520\n", - "530\n", - "540\n", - "550\n", - "560\n", - "570\n", - "580\n", - "590\n", - "600\n", - "610\n", - "620\n", - "630\n", - "640\n", - "650\n", - "660\n", - "670\n", - "680\n", - "690\n", - "700\n", - "710\n", - "720\n", - "730\n", - "740\n", - "750\n", - "760\n", - "770\n", - "780\n", - "790\n", - "800\n", - "810\n", - "820\n", - "830\n", - "840\n", - "850\n", - "860\n", - "870\n", - "880\n", - "890\n", - "900\n", - "910\n", - "920\n", - "930\n", - "940\n", - "950\n", - "960\n", - "970\n", - "980\n", - "990\n", - "1000\n", - "1010\n", - "1020\n", - "1030\n", - "1040\n", - "1050\n", - "1060\n", - "1070\n", - "1080\n", - "1090\n", - "1100\n", - "1110\n", - "1120\n", - "1130\n", - "1140\n", - "1150\n", - "1160\n", - "1170\n", - "1180\n", - "1190\n", - "1200\n", - "1210\n", - "1220\n", - "1230\n", - "1240\n", - "1250\n", - "1260\n", - "1270\n", - "1280\n", - "1290\n", - "1300\n", - "1310\n", - "1320\n", - "1330\n", - "1340\n", - "1350\n", - "1360\n", - "1370\n", - "1380\n", - "1390\n", - "1400\n", - "1410\n", - "1420\n", - "1430\n", - "1440\n", - "1450\n", - "1460\n", - "1470\n", - "1480\n", - "1490\n", - "1500\n", - "1510\n", - "1520\n", - "1530\n", - "1540\n", - "1550\n", - "1560\n", - "1570\n", - "1580\n", - "1590\n", - "1600\n", - "1610\n", - "1620\n", - "1630\n", - "1640\n", - "1650\n", - "1660\n", - "1670\n", - "1680\n", - "1690\n", - "1700\n", - "1710\n", - "1720\n", - "1730\n", - "1740\n", - "1750\n", - "1760\n", - "1770\n", - "1780\n", - "1790\n", - "1800\n", - "1810\n", - "1820\n", - "1830\n", - "1840\n", - "1850\n", - "1860\n", - "1870\n", - "1880\n", - "1890\n", - "1900\n", - "1910\n", - "1920\n", - "1930\n", - "1940\n", - "1950\n", - "1960\n", - "1970\n", - "1980\n", - "1990\n", - "2000\n", - "2010\n", - "2020\n", - "2030\n", - "2040\n", - "2050\n", - "2060\n", - "2070\n", - "2080\n", - "2090\n", - "2100\n", - "2110\n", - "2120\n", - "2130\n", - "2140\n", - "2150\n", - "2160\n", - "2170\n", - "2180\n", - "2190\n", - "2200\n", - "2210\n", - "2220\n", - "2230\n", - "2240\n", - "2250\n", - "2260\n", - "2270\n", - "2280\n", - "2290\n", - "2300\n", - "2310\n", - "2320\n", - "2330\n", - "2340\n", - "2350\n", - "2360\n", - "2370\n", - "2380\n", - "2390\n", - "2400\n", - "2410\n", - "2420\n", - "2430\n", - "2440\n", - "2450\n", - "2460\n", - "2470\n", - "2480\n", - "2490\n", - "2500\n", - "2510\n", - "2520\n", - "2530\n", - "2540\n", - "2550\n", - "2560\n", - "2570\n", - "2580\n", - "2590\n", - "2600\n", - "2610\n", - "2620\n", - "2630\n", - "2640\n", - "2650\n", - "2660\n", - "2670\n", - "2680\n", - "2690\n", - "2700\n", - "2710\n", - "2720\n", - "2730\n", - "2740\n", - "2750\n", - "2760\n", - "2770\n", - "2780\n", - "2790\n", - "2800\n", - "2810\n", - "2820\n", - "2830\n", - "2840\n", - "2850\n", - "2860\n", - "2870\n", - "2880\n", - "2890\n", - "2900\n", - "2910\n", - "2920\n", - "2930\n", - "2940\n", - "2950\n", - "2960\n", - "2970\n", - "2980\n", - "2990\n", - "3000\n", - "3010\n", - "3020\n", - "3030\n", - "3040\n", - "3050\n", - "3060\n", - "3070\n", - "3080\n", - "3090\n", - "3100\n", - "3110\n", - "3120\n", - "3130\n", - "3140\n", - "3150\n", - "3160\n", - "3170\n", - "3180\n", - "3190\n", - "3200\n", - "3210\n", - "3220\n", - "3230\n", - "3240\n", - "3250\n", - "3260\n", - "3270\n", - "3280\n", - "3290\n", - "3300\n", - "3310\n", - "3320\n", - "3330\n", - "3340\n", - "3350\n", - "3360\n", - "3370\n", - "3380\n", - "3390\n", - "3400\n", - "3410\n", - "3420\n", - "3430\n", - "3440\n", - "3450\n", - "3460\n", - "3470\n", - "3480\n", - "3490\n", - "3500\n", - "3510\n", - "3520\n", - "3530\n", - "3540\n", - "3550\n", - "3560\n", - "3570\n", - "3580\n", - "3590\n", - "3600\n", - "3610\n", - "3620\n", - "3630\n", - "3640\n", - "3650\n", - "3660\n", - "3670\n", - "3680\n", - "3690\n", - "3700\n", - "3710\n", - "3720\n", - "3730\n", - "3740\n", - "3750\n", - "3760\n", - "3770\n", - "3780\n", - "3790\n", - "3800\n", - "3810\n", - "3820\n", - "3830\n", - "3840\n", - "3850\n", - "3860\n", - "3870\n", - "3880\n", - "3890\n", - "3900\n", - "3910\n", - "3920\n", - "3930\n", - "3940\n", - "3950\n", - "3960\n", - "3970\n", - "3980\n", - "3990\n", - "4000\n", - "4010\n", - "4020\n", - "4030\n", - "4040\n", - "4050\n", - "4060\n", - "4070\n", - "4080\n", - "4090\n", - "4100\n", - "4110\n", - "4120\n", - "4130\n", - "4140\n", - "4150\n", - "4160\n", - "4170\n", - "4180\n", - "4190\n", - "4200\n", - "4210\n", - "4220\n", - "4230\n", - "4240\n", - "4250\n", - "4260\n", - "4270\n", - "4280\n", - "4290\n", - "4300\n", - "4310\n", - "4320\n", - "4330\n", - "4340\n", - "4350\n", - "4360\n", - "4370\n", - "4380\n", - "4390\n", - "4400\n", - "4410\n", - "4420\n", - "4430\n", - "4440\n", - "4450\n", - "4460\n", - "4470\n", - "4480\n", - "4490\n", - "4500\n", - "4510\n", - "4520\n", - "4530\n", - "4540\n", - "4550\n", - "4560\n", - "4570\n", - "4580\n", - "4590\n", - "4600\n", - "4610\n", - "4620\n", - "4630\n", - "4640\n", - "4650\n", - "4660\n", - "4670\n", - "4680\n", - "4690\n", - "4700\n", - "4710\n", - "4720\n", - "4730\n", - "4740\n", - "4750\n", - "4760\n", - "4770\n", - "4780\n", - "4790\n", - "4800\n", - "4810\n", - "4820\n", - "4830\n", - "4840\n", - "4850\n", - "4860\n", - "4870\n", - "4880\n", - "4890\n", - "4900\n", - "4910\n", - "4920\n", - "4930\n", - "4940\n", - "4950\n", - "4960\n", - "4970\n", - "4980\n", - "4990\n", - "5000\n", - "5010\n", - "5020\n", - "5030\n", - "5040\n", - "5050\n", - "5060\n", - "5070\n", - "5080\n", - "5090\n", - "5100\n", - "5110\n", - "5120\n", - "5130\n", - "5140\n", - "5150\n", - "5160\n", - "5170\n", - "5180\n", - "5190\n", - "5200\n", - "5210\n", - "5220\n", - "5230\n", - "5240\n", - "5250\n", - "5260\n", - "5270\n", - "5280\n", - "5290\n", - "5300\n", - "5310\n", - "5320\n", - "5330\n", - "5340\n", - "5350\n", - "5360\n", - "5370\n", - "5380\n", - "5390\n", - "5400\n", - "5410\n", - "5420\n", - "5430\n", - "5440\n", - "5450\n", - "5460\n", - "5470\n", - "5480\n", - "5490\n", - "5500\n", - "5510\n", - "5520\n", - "5530\n", - "5540\n", - "5550\n", - "5560\n", - "5570\n", - "5580\n", - "5590\n", - "5600\n", - "5610\n", - "5620\n", - "5630\n", - "5640\n", - "5650\n", - "5660\n", - "5670\n", - "5680\n", - "5690\n", - "5700\n", - "5710\n", - "5720\n", - "5730\n", - "5740\n", - "5750\n", - "5760\n", - "5770\n", - "5780\n", - "5790\n", - "5800\n", - "5810\n", - "5820\n", - "5830\n", - "5840\n", - "5850\n", - "5860\n", - "5870\n", - "5880\n", - "5890\n", - "5900\n", - "5910\n", - "5920\n", - "5930\n", - "5940\n", - "5950\n", - "5960\n", - "5970\n", - "5980\n", - "5990\n", - "6000\n", - "6010\n", - "6020\n", - "6030\n", - "6040\n", - "6050\n", - "6060\n", - "6070\n", - "6080\n", - "6090\n", - "6100\n", - "6110\n", - "6120\n", - "6130\n", - "6140\n", - "6150\n", - "6160\n", - "6170\n", - "6180\n", - "6190\n", - "6200\n", - "6210\n", - "6220\n", - "6230\n", - "6240\n", - "6250\n", - "6260\n", - "6270\n", - "6280\n", - "6290\n", - "6300\n", - "6310\n", - "6320\n", - "6330\n", - "6340\n", - "6350\n", - "6360\n", - "6370\n", - "6380\n", - "6390\n", - "6400\n", - "6410\n", - "6420\n", - "6430\n", - "6440\n", - "6450\n", - "6460\n", - "6470\n", - "6480\n", - "6490\n", - "6500\n", - "6510\n", - "6520\n", - "6530\n", - "6540\n", - "6550\n", - "6560\n", - "6570\n", - "6580\n", - "6590\n", - "6600\n", - "6610\n", - "6620\n", - "6630\n", - "6640\n", - "6650\n", - "6660\n", - "6670\n", - "6680\n", - "6690\n", - "6700\n", - "6710\n", - "6720\n", - "6730\n", - "6740\n", - "6750\n", - "6760\n", - "6770\n", - "6780\n", - "6790\n", - "6800\n", - "6810\n", - "6820\n", - "6830\n", - "6840\n", - "6850\n", - "6860\n", - "6870\n", - "6880\n", - "6890\n", - "6900\n", - "6910\n", - "6920\n", - "6930\n", - "6940\n", - "6950\n", - "6960\n", - "6970\n", - "6980\n", - "6990\n", - "7000\n", - "7010\n", - "7020\n", - "7030\n", - "7040\n", - "7050\n", - "7060\n", - "7070\n", - "7080\n", - "7090\n", - "7100\n", - "7110\n", - "7120\n", - "7130\n", - "7140\n", - "7150\n", - "7160\n", - "7170\n", - "7180\n", - "7190\n", - "7200\n", - "7210\n", - "7220\n", - "7230\n", - "7240\n", - "7250\n", - "7260\n", - "7270\n", - "7280\n", - "7290\n", - "7300\n", - "7310\n", - "7320\n", - "7330\n", - "7340\n", - "7350\n", - "7360\n", - "7370\n", - "7380\n", - "7390\n", - "7400\n", - "7410\n", - "7420\n", - "7430\n", - "7440\n", - "7450\n", - "7460\n", - "7470\n", - "7480\n", - "7490\n", - "7500\n", - "7510\n", - "7520\n", - "7530\n", - "7540\n", - "7550\n", - "7560\n", - "7570\n", - "7580\n", - "7590\n", - "7600\n", - "7610\n", - "7620\n", - "7630\n", - "7640\n", - "7650\n", - "7660\n", - "7670\n", - "7680\n", - "7690\n", - "7700\n", - "7710\n", - "7720\n", - "7730\n", - "7740\n", - "7750\n", - "7760\n", - "7770\n", - "7780\n", - "7790\n", - "7800\n", - "7810\n", - "7820\n", - "7830\n", - "7840\n", - "7850\n", - "7860\n", - "7870\n", - "7880\n", - "7890\n", - "7900\n", - "7910\n", - "7920\n", - "7930\n", - "7940\n", - "7950\n", - "7960\n", - "7970\n", - "7980\n", - "7990\n", - "8000\n", - "8010\n", - "8020\n", - "8030\n", - "8040\n", - "8050\n", - "8060\n", - "8070\n", - "8080\n", - "8090\n", - "8100\n", - "8110\n", - "8120\n", - "8130\n", - "8140\n", - "8150\n", - "8160\n", - "8170\n", - "8180\n", - "8190\n", - "8200\n", - "8210\n", - "8220\n", - "8230\n", - "8240\n", - "8250\n", - "8260\n", - "8270\n", - "8280\n", - "8290\n", - "8300\n", - "8310\n", - "8320\n", - "8330\n", - "8340\n", - "8350\n", - "8360\n", - "8370\n", - "8380\n", - "8390\n", - "8400\n", - "8410\n", - "8420\n", - "8430\n", - "8440\n", - "8450\n", - "8460\n", - "8470\n", - "8480\n", - "8490\n", - "8500\n", - "8510\n", - "8520\n", - "8530\n", - "8540\n", - "8550\n", - "8560\n", - "8570\n", - "8580\n", - "8590\n", - "8600\n", - "8610\n", - "8620\n", - "8630\n", - "8640\n", - "8650\n", - "8660\n", - "8670\n", - "8680\n", - "8690\n", - "8700\n", - "8710\n", - "8720\n", - "8730\n", - "8740\n", - "8750\n", - "8760\n", - "8770\n", - "8780\n", - "8790\n", - "8800\n", - "8810\n", - "8820\n", - "8830\n", - "8840\n", - "8850\n", - "8860\n", - "8870\n", - "8880\n", - "8890\n", - "8900\n", - "8910\n", - "8920\n", - "8930\n", - "8940\n", - "8950\n", - "8960\n", - "8970\n", - "8980\n", - "8990\n" - ] - } - ], - "source": [ - "fs = 22050\n", - "t = 2 ** 18 / 22050\n", - "samples = torch.arange(t * fs) / fs\n", - "\n", - "for i in range(300, 8000):\n", - " f = i\n", - " signal1 = torch.sin(2 * torch.pi * f * samples)\n", - " signal2 = torch.sin(2 * torch.pi * (f*2) * samples)\n", - " stacked_signal = torch.stack((signal1, signal2)).unsqueeze(1)\n", - " stacked_signal = stacked_signal.to(device)\n", - " loss = model(stacked_signal)\n", - " loss.backward() \n", - " if i % 10 == 0:\n", - " print(i)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "bda06495-0546-4474-ba5c-bf55e4887329", - "metadata": {}, - "outputs": [], - "source": [ - "# Sample 2 sources given start noise\n", - "noise = torch.randn(2, 1, 2 ** 18)\n", - "noise = noise.to(device)\n", - "sampled = model.sample(\n", - " noise=noise,\n", - " num_steps=10 # Suggested range: 2-50\n", - ") # [2, 1, 2 ** 18]" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "2d025c1e-3618-4801-9b9b-b4e50e41dcf7", - "metadata": {}, - "outputs": [], - "source": [ - "z = sampled[1]" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "583d4d28-7b1b-463b-8642-4975b36f38f2", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 262144])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "z.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "eeec47b7-4b99-4239-9c61-fd36ad881876", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Audio(z.cpu(), rate=22050)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "4d87215c-4f2d-410b-ac33-7cc1d9f73fac", - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'sig' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn [14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m Audio(\u001b[43msig\u001b[49m[\u001b[38;5;241m0\u001b[39m], rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m22050\u001b[39m)\n", - "\u001b[0;31mNameError\u001b[0m: name 'sig' is not defined" - ] - } - ], - "source": [ - "Audio(sig[0], rate=22050)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "2ccb733d-706a-4535-93b6-73ae2469de8a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Audio(stacked_signal[1].cpu(), rate=22050)" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "0377cc63-846b-4acf-8fa9-f1d4a2b07be4", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "7999" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "i" - ] - }, - { - "cell_type": "code", - "execution_count": 70, - "id": "dcf6a106-7967-470a-932e-156b00e46ab2", - "metadata": {}, - "outputs": [], - "source": [ - "f = 4000\n", - "signal1 = torch.sin(2 * torch.pi * f * samples)\n", - "signal2 = torch.sin(2 * torch.pi * (f*2) * samples)" - ] - }, - { - "cell_type": "code", - "execution_count": 72, - "id": "fac2d679-9e68-4bcc-8119-745435d128ed", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 72, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Audio(signal1.cpu(), rate=22050)" - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "id": "ddf58e57-4660-4e1a-83e3-5909da3b42fe", - "metadata": {}, - "outputs": [], - "source": [ - "fs = 22050\n", - "f = 440\n", - "t = 2 ** 18 / 22050\n", - "samples = torch.arange(t * fs) / fs\n", - "signal = torch.sin(2 * torch.pi * f * samples)\n", - "sig = torch.stack((signal, signal)).unsqueeze(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "faef7cc2-94b0-4b85-919f-0339542570c7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 1, 262144])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sig.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "6cd94fea-3d4c-4a5b-bcba-2220fb3e9414", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "16384.0" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "262144 / 16" - ] - }, - { - "cell_type": "code", - "execution_count": 89, - "id": "a62143ce-e47b-49e8-979f-e9241068d744", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([264600])" - ] - }, - "execution_count": 89, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "signal.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "e79b1b33-1905-4ae6-9dbe-73b68eec1dc5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Audio(sig[0], rate=22050)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "a6a2bb97", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 500/500 [15:47:21<00:00, 113.68s/it] \n" - ] - } - ], - "source": [ - "epochs = 500\n", - "for i in tqdm(range(epochs)):\n", - " for batch in data:\n", - " loss = model(batch)\n", - " loss.backward()\n", - " if i % 10 == 0:\n", - " wandb.log({\"loss\": loss})\n", - " with torch.no_grad():\n", - " noise = torch.randn(1, 1, 2**17).to(device)\n", - " sampled = model.sample(noise=noise, num_steps=40)\n", - " z = sampled.view(-1)\n", - " wandb.log({f\"Audio_{i}\": wandb.Audio(z.cpu().numpy(), sample_rate=SAMPLE_RATE)})\n", - " \n", - " \n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 259, - "id": "d18e4816", - "metadata": {}, - "outputs": [], - "source": [ - "noise = torch.randn(1, 1, 2**17)\n", - "sampled = model.sample(noise=noise, num_steps=50)" - ] - }, - { - "cell_type": "code", - "execution_count": 260, - "id": "054e708f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 1, 131072]) tensor([[[-0.4879, -0.4534, -0.4094, ..., -1.0000, 0.8554, -0.9605]]])\n" - ] - } - ], - "source": [ - "print(sampled.shape, sampled)" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "fc8becc0", - "metadata": {}, - "outputs": [], - "source": [ - "z = sampled.view(-1)\n", - "# z = z.mean(axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2c2296ba-7e43-4155-a754-349a7ee5f519", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "887fc2c1-de1a-4847-86ca-88b7c59f45fb", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "55e3555b-3f88-4a33-9fc8-a47bf5f28df7", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}