Spaces:
Runtime error
Runtime error
| import torchaudio | |
| from preprocessing.preprocess import AudioPipeline | |
| from dancer_net.dancer_net import ShortChunkCNN | |
| import torch | |
| import numpy as np | |
| import os | |
| import json | |
if __name__ == "__main__":
    # Run the trained network on the opening seconds of one audio file and
    # print every class label whose score exceeds `threshold`.
    audio_file = "data/samples/mzm.iqskzxzx.aac.p.m4a.wav"
    seconds = 6  # only the first `seconds` of the clip are classified
    model_path = "logs/20221226-230930"
    weights = os.path.join(model_path, "dancer_net.pt")
    config_path = os.path.join(model_path, "config.json")
    threshold = 0.5

    # Prefer Apple's MPS backend, but fall back to the CPU so the script
    # still runs on machines (or torch builds) where MPS is unavailable.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ready = mps_backend is not None and mps_backend.is_available()
    device = "mps" if mps_ready else "cpu"

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    # Class names are sorted; presumably this matches the label order used
    # at training time -- TODO confirm against the training code.
    labels = np.array(sorted(config["classes"]))

    audio_pipeline = AudioPipeline()
    waveform, sample_rate = torchaudio.load(audio_file)
    # Keep at most `seconds` worth of samples along the last axis.
    waveform = waveform[:, : seconds * sample_rate]
    spectrogram = audio_pipeline(waveform)
    # Add a leading dimension (assumed to be the batch axis the model
    # expects -- TODO confirm) and move the input to the chosen device.
    spectrogram = spectrogram.unsqueeze(0).to(device)

    model = ShortChunkCNN(n_class=len(labels))
    # map_location remaps the checkpoint's tensors onto `device`, so weights
    # saved on a different device (e.g. CUDA) still load on this machine.
    model.load_state_dict(torch.load(weights, map_location=device))
    model = model.to(device).eval()

    with torch.no_grad():
        results = model(spectrogram)
    # Drop the leading dimension and convert to NumPy for boolean indexing.
    results = results.squeeze(0).detach().cpu().numpy()
    results = results > threshold
    results = labels[results]
    print(results)