File size: 3,232 Bytes
0bd62e5 |
1 |
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: musical_instrument_identification\n",
    "### This demo identifies musical instruments from an audio file. It uses Gradio's Audio and Label components.\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q gradio torch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 librosa==0.9.2 gdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the preprocessing helpers (`resample`, `audio_preprocess`) from the demo repo.\n",
    "# -nc skips the download if the file already exists, so re-running this cell does not\n",
    "# pile up numbered copies (data_setups.py.1, ...).\n",
    "!wget -q -nc https://github.com/gradio-app/gradio/raw/main/demo/musical_instrument_identification/data_setups.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380577570523278879349135829904343037",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from timeit import default_timer as timer\n",
    "\n",
    "import gdown\n",
    "import gradio as gr\n",
    "import torch\n",
    "import torchaudio\n",
    "\n",
    "from data_setups import audio_preprocess, resample\n",
    "\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "SAMPLE_RATE = 44100  # Hz; every clip is resampled to this rate before preprocessing\n",
    "AUDIO_LEN = 2.90  # seconds of audio fed to the model\n",
    "N_SAMPLES = int(AUDIO_LEN * SAMPLE_RATE)  # clips are truncated to exactly this many samples\n",
    "MODEL_PATH = \"torch_efficientnet_fold2_CNN.pth\"\n",
    "EXAMPLE_PATH = \"piano.wav\"\n",
    "# Google Drive sources for the files above; fetched by the next cell when missing.\n",
    "DRIVE_URLS = {\n",
    "    EXAMPLE_PATH: \"https://drive.google.com/uc?id=1X5CR18u0I-ZOi_8P0cNptCe5JGk9Ro0C\",\n",
    "    MODEL_PATH: \"https://drive.google.com/uc?id=1W-8HwmGR5SiyDbUcGAZYYDKdCIst07__\",\n",
    "}\n",
    "# LABELS[i] names output column i of the model.\n",
    "LABELS = [\n",
    "    \"Cello\", \"Clarinet\", \"Flute\", \"Acoustic Guitar\", \"Electric Guitar\", \"Organ\",\n",
    "    \"Piano\", \"Saxophone\", \"Trumpet\", \"Violin\", \"Voice\",\n",
    "]\n",
    "example_list = [[EXAMPLE_PATH]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "download-assets",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the example clip and the trained model, skipping files that already exist\n",
    "# so re-running the notebook does not download them again.\n",
    "for output_path, url in DRIVE_URLS.items():\n",
    "    if not os.path.exists(output_path):\n",
    "        gdown.download(url, output_path, quiet=False)\n",
    "\n",
    "missing = [path for path in DRIVE_URLS if not os.path.exists(path)]\n",
    "assert not missing, f\"Download failed for: {missing}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "load-model",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SECURITY: torch.load unpickles the saved object, which can execute arbitrary code;\n",
    "# only load checkpoints from a trusted source.\n",
    "model = torch.load(MODEL_PATH, map_location=DEVICE)\n",
    "# Inference only: set eval mode once here instead of on every request\n",
    "# (the trailing `;` hides the echoed module repr).\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "define-predict",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(audio_path):\n",
    "    \"\"\"Classify the musical instrument in an audio clip.\n",
    "\n",
    "    Args:\n",
    "        audio_path: Path to an audio file readable by ``torchaudio.load`` (the\n",
    "            ``gr.Audio(type=\"filepath\")`` input passes a path; it may be None\n",
    "            if the user submits without a clip).\n",
    "\n",
    "    Returns:\n",
    "        ``(pred_labels_and_probs, pred_time)``: a dict mapping every name in\n",
    "        ``LABELS`` to its softmax probability, and the wall-clock prediction\n",
    "        time in seconds rounded to 5 decimals.\n",
    "\n",
    "    Raises:\n",
    "        gr.Error: If no audio was given or the resampled clip is shorter than\n",
    "            ``N_SAMPLES`` samples; Gradio displays the message to the user.\n",
    "    \"\"\"\n",
    "    if audio_path is None:\n",
    "        raise gr.Error(\"No audio provided: upload or record a clip first.\")\n",
    "    start_time = timer()\n",
    "    waveform, sample_rate = torchaudio.load(audio_path)\n",
    "    wav = resample(waveform, sample_rate, SAMPLE_RATE)\n",
    "    # NOTE(review): resample presumably returns mono 1-D audio. shape[-1] and\n",
    "    # [..., :N_SAMPLES] address the time axis either way, whereas len() would\n",
    "    # count channels if a 2-D (channels, time) tensor ever came back.\n",
    "    n_samples = wav.shape[-1]\n",
    "    if n_samples < N_SAMPLES:\n",
    "        # Raise instead of print + bare return: returning None does not match the\n",
    "        # two declared outputs, and the print never reaches the user.\n",
    "        raise gr.Error(\n",
    "            f\"Audio too short: got {n_samples} samples, need at least {N_SAMPLES} \"\n",
    "            f\"({AUDIO_LEN} s at {SAMPLE_RATE} Hz).\"\n",
    "        )\n",
    "    wav = wav[..., :N_SAMPLES]\n",
    "    img = audio_preprocess(wav, SAMPLE_RATE).unsqueeze(0).to(DEVICE)  # add batch dim\n",
    "    with torch.inference_mode():\n",
    "        probs = torch.softmax(model(img), dim=1)[0].tolist()\n",
    "    pred_labels_and_probs = dict(zip(LABELS, probs))\n",
    "    pred_time = round(timer() - start_time, 5)\n",
    "    return pred_labels_and_probs, pred_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "launch-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "demo = gr.Interface(\n",
    "    fn=predict,\n",
    "    inputs=gr.Audio(type=\"filepath\"),\n",
    "    outputs=[\n",
    "        gr.Label(num_top_classes=len(LABELS), label=\"Predictions\"),\n",
    "        gr.Number(label=\"Prediction time (s)\"),\n",
    "    ],\n",
    "    examples=example_list,\n",
    "    cache_examples=False,\n",
    ")\n",
    "\n",
    "demo.launch(debug=False)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}