| """ | |
| TODO: train a linear probe | |
| usage: | |
| python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output | |
| """ | |
| from pathlib import Path | |
| from typing import List | |
| import audiotools as at | |
| from audiotools import AudioSignal | |
| import argbind | |
| import torch | |
| import numpy as np | |
| import zipfile | |
| import json | |
| from vampnet.interface import Interface | |
| import tqdm | |
| # bind the Interface to argbind | |
| Interface = argbind.bind(Interface) | |
| DEBUG = False | |
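
# note (sketch of how the binding is used): argbind.bind(Interface) exposes the Interface
# constructor's keyword arguments as CLI flags (e.g. --Interface.device cuda), and
# --args.load lets them be read from a yml file, which is what the usage string above relies on.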
def smart_plotly_export(fig, save_path: Path):
    img_format = save_path.suffix[1:]
    if img_format == "html":
        fig.write_html(save_path)
    elif img_format == "bytes":
        return fig.to_image(format="png")
    # TODO: come back and make this prettier
    elif img_format == "numpy":
        import io
        from PIL import Image

        def plotly_fig2array(fig):
            # convert a Plotly fig to a numpy array
            fig_bytes = fig.to_image(format="png", width=1200, height=700)
            buf = io.BytesIO(fig_bytes)
            img = Image.open(buf)
            return np.asarray(img)

        return plotly_fig2array(fig)
    elif img_format in ("jpeg", "png", "webp"):
        fig.write_image(save_path)
    else:
        raise ValueError("invalid image format")
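
# usage sketch (illustrative, not called by the script): the export behavior is chosen
# purely by the suffix of save_path.
#   smart_plotly_export(fig, Path("plot.html"))   # writes an interactive html figure
#   smart_plotly_export(fig, Path("plot.numpy"))  # returns the rendered figure as a numpy array
#   smart_plotly_export(fig, Path("plot.png"))    # writes a static image (plotly typically needs kaleido for this)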
def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="tsne"):
    """
    dimensionality reduction for visualization!
    saves an html plotly figure to output_dir
    parameters:
        annotated_embeddings (list): the annotated embeddings to be reduced; each embedding has shape (layers, features)
        layer (int): which layer's embeddings to visualize
        output_dir (Path): directory where the figure is saved
        n_components (int): number of components to reduce to (2 or 3)
        method (str): umap, tsne, or pca
    returns:
        the result of smart_plotly_export for the saved figure
    """
    import pandas as pd
    import plotly.express as px

    fig_name = f"vampnet-embeddings-layer={layer}"
    fig_title = f"{fig_name}_{method}"
    save_path = (output_dir / fig_name).with_suffix(".html")

    if method == "umap":
        from umap import UMAP
        reducer = UMAP(n_components=n_components)
    elif method == "tsne":
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=n_components)
    elif method == "pca":
        from sklearn.decomposition import PCA
        reducer = PCA(n_components=n_components)
    else:
        raise ValueError(f"invalid method: {method}")

    labels = [emb.label for emb in annotated_embeddings]
    names = [emb.filename for emb in annotated_embeddings]
    embs = [emb.embedding for emb in annotated_embeddings]

    # stack to (samples, layers, features) and keep only the requested layer
    embs_at_layer = np.stack(embs)[:, layer, :]
    projs = reducer.fit_transform(embs_at_layer)

    df = pd.DataFrame(
        {
            "label": labels,
            "name": names,
            "x": projs[:, 0],
            "y": projs[:, 1],
        }
    )
    if n_components == 2:
        fig = px.scatter(
            df, x="x", y="y", color="label", hover_name="name", title=fig_title,
        )
    elif n_components == 3:
        df["z"] = projs[:, 2]
        fig = px.scatter_3d(
            df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
        )
    else:
        raise ValueError(f"can't plot {n_components} components")

    fig.update_traces(
        marker=dict(size=6, line=dict(width=1, color="DarkSlateGrey")),
        selector=dict(mode="markers"),
    )
    return smart_plotly_export(fig, save_path)
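
# usage sketch (assumes annotated_embeddings is a list of the AnnotatedEmbedding objects
# built in main() below, each holding a (layers, features) embedding array):
#   dim_reduce(annotated_embeddings, layer=10, output_dir=Path("./plots"),
#              n_components=2, method="tsne")
# which writes ./plots/vampnet-embeddings-layer=10.html with a 2-D t-SNE scatter.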
# per JukeMIR, we want the embeddings from the middle layer?
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
    with torch.inference_mode():
        # preprocess the signal
        sig = interface.preprocess(sig)

        # get the coarse vampnet model
        vampnet = interface.coarse

        # get the tokens
        z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
        z_latents = vampnet.embedding.from_codes(z, interface.codec)

        # do a forward pass through the model, get the embeddings
        _z, embeddings = vampnet(z_latents, return_activations=True)
        # print(f"got embeddings with shape {embeddings.shape}")
        # [layer, batch, time, n_dims]
        # [20, 1, 600ish, 768]

        # squeeze the batch dim (1 because layer should be dim 0)
        assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[1]}"
        embeddings = embeddings.squeeze(1)

        num_layers = embeddings.shape[0]
        assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"

        # do meanpooling over the time dimension
        embeddings = embeddings.mean(dim=-2)
        # [20, 768]

        # return the embeddings
        return embeddings
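
# shape sketch (assuming the default pretrained coarse model, ~20 layers with 768-dim activations):
#   sig = AudioSignal("some_clip.wav")     # hypothetical input file
#   emb = vampnet_embed(sig, interface)    # torch.Tensor of shape [n_layers, n_dims], e.g. [20, 768]
#   emb = emb.cpu().numpy()                # this is what gets stored in AnnotatedEmbedding below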
from dataclasses import dataclass, fields

@dataclass
class AnnotatedEmbedding:
    label: str
    filename: str
    embedding: np.ndarray

    def save(self, path):
        """Save the Embedding object to a given path as a zip file."""
        with zipfile.ZipFile(path, 'w') as archive:

            # Save numpy array
            with archive.open('embedding.npy', 'w') as f:
                np.save(f, self.embedding)

            # Save non-numpy data as json
            non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
            with archive.open('data.json', 'w') as f:
                f.write(json.dumps(non_numpy_data).encode('utf-8'))

    @classmethod
    def load(cls, path):
        """Load the Embedding object from a given zip path."""
        with zipfile.ZipFile(path, 'r') as archive:

            # Load numpy array
            with archive.open('embedding.npy') as f:
                embedding = np.load(f)

            # Load non-numpy data from json
            with archive.open('data.json') as f:
                data = json.loads(f.read().decode('utf-8'))

        return cls(embedding=embedding, **data)
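
# round-trip sketch (hypothetical values): a .emb file is just a zip containing
# embedding.npy and data.json, so save() and load() should be inverses.
#   emb = AnnotatedEmbedding(label="blues", filename="blues.00000.wav",
#                            embedding=np.zeros((20, 768), dtype=np.float32))
#   emb.save("blues.00000.emb")
#   restored = AnnotatedEmbedding.load("blues.00000.emb")
#   assert restored.label == "blues" and restored.embedding.shape == (20, 768)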
# bind main so its keyword arguments become top-level CLI flags (e.g. --path_to_audio)
@argbind.bind(without_prefix=True)
def main(
    path_to_audio: str = None,
    cache_dir: str = "./.emb_cache",
    output_dir: str = "./vampnet_embeddings",
    layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
    method: str = "tsne",
    n_components: int = 2,
):
    path_to_audio = Path(path_to_audio)
    assert path_to_audio.exists(), f"{path_to_audio} does not exist"

    cache_dir = Path(cache_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # load our interface
    # argbind will automatically load the default config
    interface = Interface()

    # we expect path_to_audio to consist of a folder for each label, so let's get the list of labels
    labels = [Path(x).name for x in path_to_audio.iterdir() if x.is_dir()]
    print(f"Found {len(labels)} labels")
    print(f"labels: {labels}")

    # collect audio files, labels, and embeddings
    annotated_embeddings = []
    for label in labels:
        audio_files = list(at.util.find_audio(path_to_audio / label))
        print(f"Found {len(audio_files)} audio files for label {label}")

        for audio_file in tqdm.tqdm(audio_files, desc=f"embedding label {label}"):
            # check if we have a cached embedding for this file
            cached_path = cache_dir / f"{label}_{audio_file.stem}.emb"
            if cached_path.exists():
                # if so, load it
                if DEBUG:
                    print(f"loading cached embedding for {cached_path.stem}")
                embedding = AnnotatedEmbedding.load(cached_path)
            else:
                try:
                    sig = AudioSignal(audio_file)
                except Exception as e:
                    print(f"failed to load {audio_file.name} with error {e}")
                    print(f"skipping {audio_file.name}")
                    continue

                # gets the embedding
                emb = vampnet_embed(sig, interface).cpu().numpy()

                # create an embedding we can save/load
                embedding = AnnotatedEmbedding(
                    label=label, filename=audio_file.name, embedding=emb
                )

                # cache the embeddings
                cached_path.parent.mkdir(exist_ok=True, parents=True)
                embedding.save(cached_path)

            annotated_embeddings.append(embedding)

    # now, let's do a dim reduction on the embeddings and visualize them.
    for layer in tqdm.tqdm(layers, desc="dim reduction"):
        dim_reduce(
            annotated_embeddings,
            layer,
            output_dir=output_dir,
            n_components=n_components,
            method=method,
        )


if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        main()
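
# expected layout for --path_to_audio (a sketch; one subfolder per label, e.g. GTZAN genre
# folders, containing any audio formats that audiotools can read):
#   /path/to/audio/labels/
#       blues/blues.00000.wav
#       rock/rock.00000.wav
#       ...
# then run, per the usage string at the top of this file:
#   python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda \
#       --path_to_audio /path/to/audio/labels --output_dir /path/to/output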