from pathlib import Path
import gradio as gr
import numpy as np
import torch
from preprocessing.preprocess import AudioPipeline
from dancer_net.dancer_net import ShortChunkCNN
import os
import json
from functools import cache
import pandas as pd


# functools.cache keeps the loaded model, pipeline, and mapping in memory so
# they are not rebuilt on every prediction.
@cache
def get_model(device: str) -> tuple[ShortChunkCNN, np.ndarray]:
    """Load the trained ShortChunkCNN and its sorted class labels."""
    model_path = "logs/20221226-230930"
    weights = os.path.join(model_path, "dancer_net.pt")
    config_path = os.path.join(model_path, "config.json")
    with open(config_path) as f:
        config = json.load(f)
    labels = np.array(sorted(config["classes"]))
    model = ShortChunkCNN(n_class=len(labels))
    # map_location keeps the checkpoint loadable even if it was saved on a
    # different device than the one we are running on.
    model.load_state_dict(torch.load(weights, map_location=device))
    model = model.to(device).eval()
    return model, labels

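# Note: get_model() assumes a checkpoint directory holding the weights file
# ("dancer_net.pt") next to a "config.json" whose "classes" entry lists the
# class ids used during training, roughly:
#
#   {"classes": ["<class-id-1>", "<class-id-2>", "..."]}
#
# Anything in config.json beyond the "classes" key is an assumption here.
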
@cache
def get_pipeline(sample_rate: int) -> AudioPipeline:
    return AudioPipeline(input_freq=sample_rate)

@cache
def get_dance_map() -> dict:
    df = pd.read_csv("data/dance_mapping.csv")
    return df.set_index("id").to_dict()["name"]

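# get_dance_map() expects data/dance_mapping.csv to map class ids to display
# names via an "id" and a "name" column, e.g. (placeholder values only):
#
#   id,name
#   <class-id>,<human-readable dance name>
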
def predict(audio: tuple[int, np.ndarray]) -> dict[str, float] | str:
    """Classify a recording and return the dances whose score clears the threshold."""
    sample_rate, waveform = audio
    expected_duration = 6  # seconds of audio the model expects
    threshold = 0.5
    sample_len = sample_rate * expected_duration
    # Prefer Apple's MPS backend when it is available, otherwise fall back to
    # CPU so the app still runs on ordinary hardware.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    audio_pipeline = get_pipeline(sample_rate)
    model, labels = get_model(device)
    if sample_len > len(waveform):
        raise gr.Error("You must record for at least 6 seconds")
    # Gradio delivers audio as (sample_rate, ndarray): stereo recordings arrive
    # as (samples, channels), mono as (samples,). Collapse to a single channel
    # with shape (1, samples).
    if len(waveform.shape) > 1 and waveform.shape[1] > 1:
        waveform = waveform.transpose(1, 0)
        waveform = waveform.mean(axis=0, keepdims=True)
    else:
        waveform = np.expand_dims(waveform, 0)
    waveform = waveform[:, :sample_len]
    # Rescale the integer PCM samples into the [-1, 1] float range.
    waveform = (waveform - waveform.min()) / (waveform.max() - waveform.min()) * 2 - 1
    waveform = waveform.astype("float32")
    waveform = torch.from_numpy(waveform)
    spectrogram = audio_pipeline(waveform)
    spectrogram = spectrogram.unsqueeze(0).to(device)
    with torch.no_grad():
        results = model(spectrogram)
    dance_mapping = get_dance_map()
    results = results.squeeze(0).detach().cpu().numpy()
    # Multi-label output: keep every class whose score clears the threshold.
    result_mask = results > threshold
    probs = results[result_mask]
    dances = labels[result_mask]
    return (
        {dance_mapping[dance_id]: float(prob) for dance_id, prob in zip(dances, probs)}
        if len(dances)
        else "Couldn't find a dance."
    )

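# Minimal local smoke test for predict(), assuming a 6-second 16 kHz mono
# int16 buffer like the one Gradio's microphone component produces; the
# sample rate and random waveform below are illustrative only.
def _smoke_test(sample_rate: int = 16_000) -> None:
    fake_waveform = np.random.randint(
        -2**15, 2**15 - 1, size=sample_rate * 6, dtype=np.int16
    )
    print(predict((sample_rate, fake_waveform)))
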
def demo():
    title = "Dance Classifier"
    description = "Record 6 seconds of a song and find out what dance fits the music."
    with gr.Blocks() as app:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)
        with gr.Tab("Record Song"):
            mic_audio = gr.Audio(source="microphone", label="Song Recording")
            mic_submit = gr.Button("Predict")
        with gr.Tab("Upload Song"):
            audio_file = gr.Audio(label="Song Audio File")
            audio_file_submit = gr.Button("Predict")
        song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
        example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != '.']
        labels = gr.Label(label="Dances")
| gr.Markdown("## Examples") | |
| gr.Examples( | |
| examples=example_audio, | |
| inputs=audio_file, | |
| outputs=labels, | |
| fn=predict, | |
| ) | |
| audio_file_submit.click(fn=predict, inputs=audio_file, outputs=labels) | |
| mic_submit.click(fn=predict, inputs=mic_audio, outputs=labels) | |
| return app | |
if __name__ == "__main__":
    demo().launch()