diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..ec56e55e671ac22dffab21fe340bb7bebb836872 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/logo_animation_256.gif filter=lfs diff=lfs merge=lfs -text +assets/screenshot.png filter=lfs diff=lfs merge=lfs -text diff --git a/CHANGELOG.md b/CHANGELOG.md index fe15d47ed6a6b77862867a8e83215e8ffd8b0ede..d7335673d8e567c032854c691ed2673419bcc0da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,69 @@ -## [0.0.2a2] - 2023-07-20 +# Changelog + +All notable changes to this project will be documented in this file. + +## [1.2.Surn] - 2025-04-02 + +Implemented Unlimited Music Generation (UMG) with the [hf checkpoints](https://huggingface.co/facebook/unlimited-music-generation). + +## [1.4.0a2] - 2025-01-14 + +Add training and inference code for JASCO (https://arxiv.org/abs/2406.10970) along with the [hf checkpoints](https://huggingface.co/facebook/jasco-chords-drums-melody-1B). + +## [1.4.0a1] - 2024-06-03 + +Adding new metric PesqMetric ([Perceptual Evaluation of Speech Quality](https://doi.org/10.5281/zenodo.6549559)) + +Adding multiple audio augmentation functions: generating pink noises, up-/downsampling, low-/highpass filtering, banpass filtering, smoothing, duck masking, boosting. All are wrapped in the `audiocraft.utils.audio_effects.AudioEffects` and can be called with the API `audiocraft.utils.audio_effects.select_audio_effects`. + +Add training code for AudioSeal (https://arxiv.org/abs/2401.17264) along with the [hf checkpoints]( https://huggingface.co/facebook/audioseal). + +## [1.3.0] - 2024-05-02 + +Adding the MAGNeT model (https://arxiv.org/abs/2401.04577) along with hf checkpoints and a gradio demo app. + +Typo fixes. + +Fixing setup.py to install only audiocraft, not the unit tests and scripts. + +Fix FSDP support with PyTorch 2.1.0. + +## [1.2.0] - 2024-01-11 -Music Generation set to a max of 720 seconds (12 minutes) to avoid memory issues. +Adding stereo models. -Video editing options (thanks @Surn and @oncorporation). +Fixed the commitment loss, which was until now only applied to the first RVQ layer. -Music Conditioning segment options +Removed compression model state from the LM checkpoints, for consistency, it +should always be loaded from the original `compression_model_checkpoint`. -## [0.0.2a] - TBD +## [1.1.0] - 2023-11-06 + +Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg. Also not using it anymore for reading audio files, for similar reasons. + +Fixed DAC support with non default number of codebooks. + +Fixed bug when `two_step_cfg` was overriden when calling `generate()`. + +Fixed samples being always prompted with audio, rather than having both prompted and unprompted. + +**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way in the public release. + The released models were trained without this. Those impact linear layers applied to the output of the T5 or melody conditioners. + We removed it, so you might need to retrain models. + +**Backward incompatible change:** Fixing wrong sample rate in CLAP (WARNING if you trained model with CLAP before). + +**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. 
Probably no one + retrained a model with this pattern, so hopefully this won't impact you! + + +## [1.0.0] - 2023-09-07 + +Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion. +Added pretrained model for AudioGen and MultiBandDiffusion. + +## [0.0.2] - 2023-08-01 Improved demo, fixed top p (thanks @jnordberg). @@ -24,10 +80,3 @@ Note that other implementations exist: https://github.com/camenduru/MusicGen-col ## [0.0.1] - 2023-06-09 Initial release, with model evaluation only. - - -# Changelog - -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). diff --git a/README.md b/README.md index cc215871916c54331b24058899924c8324ed850a..3d47b6fa5a98b486643e47282a45de2ee41bd7ab 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,21 @@ emoji: 🎼 colorFrom: gray colorTo: red sdk: gradio -sdk_version: 3.38.0 +sdk_version: 5.23.3 +python_version: 3.12.8 app_file: app.py -pinned: false +pinned: true license: creativeml-openrail-m tags: - musicgen - unlimited +- user history +- metadata +hf_oauth: true +disable_embedding: true +short_description: 'unlimited Audio generation with a few added features ' +thumbnail: >- + https://cdn-uploads.huggingface.co/production/uploads/6346595c9e5f0fe83fc60444/Z8E8OaKV84zuVAvvGpMDJ.png --- [arxiv]: https://arxiv.org/abs/2306.05284 @@ -18,7 +26,18 @@ tags: Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference # UnlimitedMusicGen -This is my modification of the Audiocraft project to enable unlimited Audio generation. I have added a few features to the original project to enable this. I have also added a few features to the gradio interface to make it easier to use. +Charles Fettinger's modification of the Audiocraft project to enable unlimited Audio generation. I have added a few features to the original project to enable this. I have also added a few features to the gradio interface to make it easier to use. + +Please review my other AI relalated spaces at https://huggingface.co/Surn + +Check your video's generative metadata with https://mediaarea.net/en/MediaInfo + +Also note that I wrote an extension to Gradio for the waveform in the video after v4.48.0 removed it. + +The key update here is in the extend utility. We segment melody input and then condition the next segment with current tensors and tensors from the current time in the conditioning melody file. +This allows us to follow the same arraingement of the original melody. + +**Thank you Huggingface for the community grant to run this project**!! # Audiocraft ![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg) @@ -46,12 +65,12 @@ Check out our [sample page][musicgen_samples] or test the available demo! We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data. ## Installation -Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following: +Audiocraft requires Python 3.9, PyTorch 2.1.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following: ```shell # Best to make sure you have torch installed first, in particular before installing xformers. # Don't run this if you already have PyTorch installed. 
-pip install 'torch>=2.0' +pip install 'torch>=2.1' # Then proceed to one of the following pip install -U audiocraft # stable release pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge @@ -60,7 +79,7 @@ pip install -e . # or if you cloned the repo locally ## Usage We offer a number of way to interact with MusicGen: -1. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support). +1. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/Surn/UnlimitedMusicGen) (huge thanks to all the HF team for their support). 2. You can run the Gradio demo in Colab: [colab notebook](https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing). 3. You can use the gradio demo locally by running `python app.py`. 4. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally (if you have a GPU). @@ -178,6 +197,25 @@ For more details on using the MusicGen model for inference using the 🤗 Transf [MusicGen docs](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen) or the hands-on [Google Colab](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/MusicGen.ipynb). +## User History + +User History is a plugin that you can add to your Spaces to cache generated images for your users. + +Key features: +- 🤗 Sign in with Hugging Face +- Save generated image, video, audio and document files with their metadata: prompts, timestamp, hyper-parameters, etc. +- Export your history as zip. +- Delete your history to respect privacy. +- Compatible with Persistent Storage for long-term storage. +- Admin panel to check configuration and disk usage . + +Useful links: +- Demo: https://huggingface.co/spaces/Wauplin/gradio-user-history +- README: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/README.md +- Source file: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/user_history.py +- Discussions: https://huggingface.co/spaces/Wauplin/gradio-user-history/discussions + +![Image preview](./assets/screenshot.png) ## Model Card @@ -212,4 +250,7 @@ Check [@camenduru tutorial on Youtube](https://www.youtube.com/watch?v=EGfxuTy9E * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE). * The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights). 
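To make the User History description above concrete, here is a minimal sketch of how a finished generation is recorded; it mirrors the `modules.user_history.save_file(...)` call added to `predict()` in `app.py`, and the argument names (`profile`, `background_path`, `wav_file`, `video_path`, `prompt`, `metadata`) are illustrative stand-ins for the values built there.

```python
# Sketch: persist one generation in the Space's "User History" tab.
# Assumes the Wauplin/gradio-user-history helper is vendored as modules/user_history.py,
# as in this repository; names below are placeholders for values built in predict().
import modules.user_history

def record_generation(profile, background_path, wav_file, video_path, prompt, metadata):
    """Save the image/audio/video of one run plus its metadata for the signed-in user."""
    modules.user_history.save_file(
        profile=profile,        # gr.OAuthProfile | None from "Sign in with Hugging Face"
        image=background_path,  # background image used for the waveform video
        audio=wav_file,         # generated .wav file (path or file object)
        video=video_path,       # waveform video, already tagged with MP4 metadata
        label=prompt,           # text prompt shown as the history entry label
        metadata=metadata,      # dict of seed, model, hyper-parameters, versions, ...
    )
```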
[arxiv]: https://arxiv.org/abs/2306.05284 -[musicgen_samples]: https://ai.honu.io/papers/musicgen/ \ No newline at end of file + +[arxiv]: https://arxiv.org/abs/2306.05284 +[musicgen_samples]: https://ai.honu.io/papers/musicgen/ +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app.py b/app.py index b57c52f8b3bd37d65628adf5cfab66a4b60c694b..7f083731e69338cc164c9a2fd87cc528b6cb2dcb 100644 --- a/app.py +++ b/app.py @@ -21,11 +21,17 @@ from audiocraft.models import MusicGen from audiocraft.data.audio import audio_write from audiocraft.data.audio_utils import apply_fade, apply_tafade, apply_splice_effect from audiocraft.utils.extend import generate_music_segments, add_settings_to_image, INTERRUPTING +from audiocraft.utils import utils import numpy as np import random -#from pathlib import Path +import shutil +from mutagen.mp4 import MP4 #from typing import List, Union import librosa +import modules.user_history +from modules.version_info import versions_html, commit_hash, get_xformers_version +from modules.gradio import * +from modules.file_utils import get_file_parts, get_filename_from_filepath, convert_title_to_filename, get_filename, delete_file MODEL = None MODELS = None @@ -35,7 +41,12 @@ UNLOAD_MODEL = False MOVE_TO_CPU = False MAX_PROMPT_INDEX = 0 git = os.environ.get('GIT', "git") -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" +#s.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" +os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' +os.environ['CUDA_MODULE_LOADING']='LAZY' +os.environ['USE_FLASH_ATTENTION'] = '1' +os.environ['XFORMERS_FORCE_DISABLE_TRITON']= '1' def interrupt_callback(): return INTERRUPTED @@ -72,7 +83,7 @@ def toggle_audio_src(choice): else: return gr.update(source="upload", value=None, label="File") -def make_waveform(*args, **kwargs): +def get_waveform(*args, **kwargs): # Further remove some warnings. 
be = time.time() with warnings.catch_warnings(): @@ -80,6 +91,7 @@ def make_waveform(*args, **kwargs): out = gr.make_waveform(*args, **kwargs) print("Make a video took", time.time() - be) return out + def load_model(version): global MODEL, MODELS, UNLOAD_MODEL @@ -102,32 +114,12 @@ def load_model(version): print("Cached model loaded in %.2fs" % (time.monotonic() - t1)) return result -def get_filename(file): - # extract filename from file object - filename = None - if file is not None: - filename = file.name - return filename - -def get_filename_from_filepath(filepath): - file_name = os.path.basename(filepath) - file_base, file_extension = os.path.splitext(file_name) - return file_base, file_extension - def get_melody(melody_filepath): audio_data= list(librosa.load(melody_filepath, sr=None)) audio_data[0], audio_data[1] = audio_data[1], audio_data[0] melody = tuple(audio_data) return melody - -def commit_hash(): - try: - return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip() - except Exception: - return "" - - def git_tag(): try: return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip() @@ -140,28 +132,6 @@ def git_tag(): except Exception: return "" -def versions_html(): - import torch - - python_version = ".".join([str(x) for x in sys.version_info[0:3]]) - commit = commit_hash() - #tag = git_tag() - - import xformers - xformers_version = xformers.__version__ - - return f""" - version: " else commit}" target="_blank">{"huggingface" if commit == "" else commit} -  •  - python: {python_version} -  •  - torch: {getattr(torch, '__long_version__',torch.__version__)} -  •  - xformers: {xformers_version} -  •  - gradio: {gr.__version__} - """ - def load_melody_filepath(melody_filepath, title): # get melody filename #$Union[str, os.PathLike] @@ -187,12 +157,13 @@ def load_melody_filepath(melody_filepath, title): print(f"Melody length: {len(melody_data)}, Melody segments: {total_melodys}\n") MAX_PROMPT_INDEX = total_melodys - return gr.Textbox.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=0), gr.update(value="melody-large", interactive=True) + return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=0), gr.update(value="melody", interactive=True) def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False): global MODEL, INTERRUPTED, INTERRUPTING, MOVE_TO_CPU output_segments = None melody_name = "Not Used" + melody_extension = "Not Used" melody = None if melody_filepath: melody_name, melody_extension = get_filename_from_filepath(melody_filepath) @@ -201,17 +172,23 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe INTERRUPTED = False INTERRUPTING = False if temperature < 0: + temperature -0 raise gr.Error("Temperature must be >= 0.") if topk < 0: + topk = 1 raise gr.Error("Topk must be non-negative.") if topp < 0: + topp =1 raise gr.Error("Topp must be non-negative.") - if MODEL is None or MODEL.name != model: - MODEL = load_model(model) - else: - if MOVE_TO_CPU: - MODEL.to('cuda') + try: + if MODEL is None or MODEL.name != model: + MODEL = load_model(model) + else: + if MOVE_TO_CPU: + MODEL.to('cuda') + except Exception as e: + raise gr.Error(f"Error loading model '{model}': {str(e)}. 
Try a different model.") # prevent hacking duration = min(duration, 720) @@ -251,35 +228,41 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe rep_penalty=0.5 ) - if melody: - # todo return excess duration, load next model and continue in loop structure building up output_segments - if duration > MODEL.lm.cfg.dataset.segment_duration: - output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.lm.cfg.dataset.segment_duration, prompt_index, harmony_only=False) - else: - # pure original code - sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0) - print(melody.shape) - if melody.dim() == 2: - melody = melody[None] - melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)] - output = MODEL.generate_with_chroma( - descriptions=[text], - melody_wavs=melody, - melody_sample_rate=sr, - progress=False - ) - # All output_segments are populated, so we can break the loop or set duration to 0 - break - else: - #output = MODEL.generate(descriptions=[text], progress=False) - if not output_segments: - next_segment = MODEL.generate(descriptions=[text], progress=False) - duration -= segment_duration + try: + if melody: + # return excess duration, load next model and continue in loop structure building up output_segments + if duration > MODEL.lm.cfg.dataset.segment_duration: + output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.lm.cfg.dataset.segment_duration, prompt_index, harmony_only=False) + else: + # pure original code + sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0) + print(melody.shape) + if melody.dim() == 2: + melody = melody[None] + melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)] + output = MODEL.generate_with_chroma( + descriptions=[text], + melody_wavs=melody, + melody_sample_rate=sr, + progress=True + ) + # All output_segments are populated, so we can break the loop or set duration to 0 + break else: - last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:] - next_segment = MODEL.generate_continuation(last_chunk, MODEL.sample_rate, descriptions=[text], progress=False) - duration -= segment_duration - overlap - output_segments.append(next_segment) + #output = MODEL.generate(descriptions=[text], progress=False) + if not output_segments: + next_segment = MODEL.generate(descriptions=[text], progress=True) + duration -= segment_duration + else: + last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:] + next_segment = MODEL.generate_continuation(last_chunk, MODEL.sample_rate, descriptions=[text], progress=True) + duration -= segment_duration - overlap + if next_segment != None: + output_segments.append(next_segment) + except Exception as e: + print(f"Error generating audio: {e}") + gr.Error(f"Error generating audio: {e}") + return None, None, seed if INTERRUPTING: INTERRUPTED = True @@ -287,6 +270,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe print("Function execution interrupted!") raise gr.Error("Interrupted.") + print(f"\nOutput segments: {len(output_segments)}\n") if output_segments: try: # Combine the output segments into one long audio file or stack tracks @@ -312,7 +296,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe ##overlapping_output = torch.cat([output[:, :, -overlap_samples:], output_segments[i][:, :, :overlap_samples]], dim=1) #stack 
tracks ##print(f" overlap size stack:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}") #overlapping_output = torch.cat([output[:, :, -overlap_samples:], output_segments[i][:, :, :overlap_samples]], dim=2) #stack tracks - #print(f" overlap size cat:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}") + #print(f" overlap size cat:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}") output = torch.cat([output[:, :, :-overlap_samples], overlapping_output, output_segments[i][:, :, overlap_samples:]], dim=dimension) else: output = torch.cat([output, output_segments[i]], dim=dimension) @@ -321,143 +305,227 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe print(f"Error combining segments: {e}. Using the first segment only.") output = output_segments[0].detach().cpu().float()[0] else: - output = output.detach().cpu().float()[0] - - with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: + if (output is None) or (output.dim() == 0): + return None, None, seed + else: + output = output.detach().cpu().float()[0] + profile: gr.OAuthProfile | None = None + title_file_name = convert_title_to_filename(title) + with NamedTemporaryFile("wb", suffix=".wav", delete=False, prefix = title_file_name) as file: video_description = f"{text}\n Duration: {str(initial_duration)} Dimension: {dimension}\n Top-k:{topk} Top-p:{topp}\n Randomness:{temperature}\n cfg:{cfg_coef} overlap: {overlap}\n Seed: {seed}\n Model: {model}\n Melody Condition:{melody_name}\n Sample Segment: {prompt_index}" if include_settings or include_title: background = add_settings_to_image(title if include_title else "", video_description if include_settings else "", background_path=background, font=settings_font, font_color=settings_font_color) audio_write( file.name, output, MODEL.sample_rate, strategy="loudness", - loudness_headroom_db=18, loudness_compressor=True, add_suffix=False, channels=2) - waveform_video = make_waveform(file.name,bg_image=background, bar_count=45) + loudness_headroom_db=18, loudness_compressor=True, add_suffix=False, channels=2) + waveform_video_path = get_waveform(file.name,bg_image=background, bar_count=45, name = title_file_name) + # Remove the extension from file.name + file_name_without_extension = os.path.splitext(file.name)[0] + # Get the directory, filename, name, extension, and new extension of the waveform video path + video_dir, video_name, video_name, video_ext, video_new_ext = get_file_parts(waveform_video_path) + + new_video_path = os.path.join(video_dir, title_file_name + video_new_ext) + + mp4 = MP4(waveform_video_path) + mp4["©nam"] = title_file_name # Title tag + mp4["desc"] = f"{text}\n Duration: {str(initial_duration)}" # Description tag + + commit = commit_hash() + metadata={ + "prompt": text, + "negative_prompt": "", + "Seed": seed, + "steps": 1, + "width": "768px", + "height":"512px", + "Dimension": dimension, + "Top-k": topk, + "Top-p":topp, + "Randomness": temperature, + "cfg":cfg_coef, + "overlap": overlap, + "Melody Condition": melody_name, + "Sample Segment": prompt_index, + "Duration": initial_duration, + "Audio": file.name, + "font": settings_font, + "font_color": settings_font_color, + "harmony_only": harmony_only, + "background": background, + "include_title": include_title, + "include_settings": include_settings, + "profile": profile, + "commit": commit_hash(), + "tag": git_tag(), + "version": gr.__version__, + 
"model_version": MODEL.version, + "model_name": MODEL.name, + "model_description": f"{MODEL.audio_channels} channels, {MODEL.sample_rate} Hz", + "melody_name" : melody_name if melody_name else "", + "melody_extension" : melody_extension if melody_extension else "", + "hostname": "https://huggingface.co/spaces/Surn/UnlimitedMusicGen", + "version" : f"""https://huggingface.co/spaces/Surn/UnlimitedMusicGen/commit/{"huggingface" if commit == "" else commit}""", + "python" : sys.version, + "torch" : getattr(torch, '__long_version__',torch.__version__), + "xformers": get_xformers_version(), + "gradio": gr.__version__, + "huggingface_space": os.environ.get('SPACE_ID', ''), + "CUDA": f"""{"CUDA is available. device: " + torch.cuda.get_device_name(0) + " version: " + torch.version.cuda if torch.cuda.is_available() else "CUDA is not available."}""", + } + # Add additional metadata from the metadata dictionary (if it exists) + for key, value in metadata.items(): + mp4[key] = str(value) # Convert values to strings as required by mutagen + + # Save the metadata changes to the file + mp4.save() + + try: + if os.path.exists(new_video_path): + delete_file(new_video_path) + # Open the original MP4 file in binary read mode and the new file in binary write mode + with open(waveform_video_path, "rb") as src, open(new_video_path, "wb") as dst: + if os.path.exists(waveform_video_path): + # Copy the contents from the source file to the destination file + shutil.copyfileobj(src, dst) + waveform_video_path = new_video_path + except Exception as e: + print(f"Error copying file: {e}") + + if waveform_video_path: + modules.user_history.save_file( + profile=profile, + image=background, + audio=file, + video=waveform_video_path, + label=text, + metadata=metadata, + ) + + if MOVE_TO_CPU: MODEL.to('cpu') if UNLOAD_MODEL: MODEL = None torch.cuda.empty_cache() torch.cuda.ipc_collect() - return waveform_video, file.name, seed + return waveform_video_path, file.name, seed +gr.set_static_paths(paths=["fonts/","assets/"]) def ui(**kwargs): - css=""" - #col-container {max-width: 910px; margin-left: auto; margin-right: auto;} - a {text-decoration-line: underline; font-weight: 600;} - #btn-generate {background-image:linear-gradient(to right bottom, rgb(157, 255, 157), rgb(229, 255, 235));} - #btn-generate:hover {background-image:linear-gradient(to right bottom, rgb(229, 255, 229), rgb(255, 255, 255));} - #btn-generate:active {background-image:linear-gradient(to right bottom, rgb(229, 255, 235), rgb(157, 255, 157));} - #versions {margin-top: 1em; width:100%; text-align:center;} - .small-btn {max-width:75px;} - """ - with gr.Blocks(title="UnlimitedMusicGen", css=css) as demo: - gr.Markdown( - """ + with gr.Blocks(title="UnlimitedMusicGen",css_paths="style_20250331.css", theme='Surn/beeuty') as interface: + with gr.Tab("UnlimitedMusicGen"): + gr.Markdown( + """ # UnlimitedMusicGen This is your private demo for [UnlimitedMusicGen](https://github.com/Oncorporation/audiocraft), a simple and controllable model for music generation presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284) Disclaimer: This won't run on CPU only. Clone this App and run on GPU instance! - Todo: Working on improved Interrupt and new Models. 
- """ - ) - if IS_SHARED_SPACE and not torch.cuda.is_available(): - gr.Markdown(""" - ⚠ This Space doesn't work in this shared UI ⚠ - - - Duplicate Space - to use it privately, or use the public demo - """) - with gr.Row(): - with gr.Column(): - with gr.Row(): - text = gr.Text(label="Describe your music", interactive=True, value="4/4 100bpm 320kbps 48khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi") - with gr.Column(): - duration = gr.Slider(minimum=1, maximum=720, value=10, label="Duration (s)", interactive=True) - model = gr.Radio(["melody", "medium", "small", "large", "melody-large", "stereo-melody", "stereo-medium", "stereo-small", "stereo-large", "stereo-melody-large"], label="AI Model", value="melody-large", interactive=True) - with gr.Row(): - submit = gr.Button("Generate", elem_id="btn-generate") - # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license. - _ = gr.Button("Interrupt", elem_id="btn-interrupt").click(fn=interrupt, queue=False) - with gr.Row(): - with gr.Column(): - radio = gr.Radio(["file", "mic"], value="file", label="Condition on a melody (optional) File or Mic") - melody_filepath = gr.Audio(source="upload", type="filepath", label="Melody Condition (optional)", interactive=True, elem_id="melody-input") + Todo: Working on improved Interrupt. + Theme Available at ["Surn/Beeuty"](https://huggingface.co/spaces/Surn/Beeuty) + + """ + ) + if IS_SHARED_SPACE and not torch.cuda.is_available(): + gr.Markdown(""" + ⚠ This Space doesn't work in this shared UI ⚠ + + + Duplicate Space + to use it privately, or use the public demo + """) + with gr.Row(): with gr.Column(): - harmony_only = gr.Radio(label="Use Harmony Only",choices=["No", "Yes"], value="No", interactive=True, info="Remove Drums?") - prompt_index = gr.Slider(label="Melody Condition Sample Segment", minimum=-1, maximum=MAX_PROMPT_INDEX, step=1, value=0, interactive=True, info="Which 30 second segment to condition with, - 1 condition each segment independantly") - with gr.Accordion("Video", open=False): - with gr.Row(): - background= gr.Image(value="./assets/background.png", source="upload", label="Background", shape=(768,512), type="filepath", interactive=True) - with gr.Column(): - include_title = gr.Checkbox(label="Add Title", value=True, interactive=True) - include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True) - with gr.Row(): - title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True) - settings_font = gr.Text(label="Settings Font", value="./assets/arial.ttf", interactive=True) - settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#c87f05", interactive=True) - with gr.Accordion("Expert", open=False): - with gr.Row(): - overlap = gr.Slider(minimum=0, maximum=15, value=2, step=1, label="Verse Overlap", interactive=True) - dimension = gr.Slider(minimum=-2, maximum=2, value=2, step=1, label="Dimension", info="determines which direction to add new segements of audio. 
(1 = stack tracks, 2 = lengthen, -2..0 = ?)", interactive=True) - with gr.Row(): - topk = gr.Number(label="Top-k", value=280, precision=0, interactive=True) - topp = gr.Number(label="Top-p", value=1150, precision=0, interactive=True) - temperature = gr.Number(label="Randomness Temperature", value=0.7, precision=None, interactive=True) - cfg_coef = gr.Number(label="Classifier Free Guidance", value=8.5, precision=None, interactive=True) - with gr.Row(): - seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True) - gr.Button('\U0001f3b2\ufe0f', elem_classes="small-btn").click(fn=lambda: -1, outputs=[seed], queue=False) - reuse_seed = gr.Button('\u267b\ufe0f', elem_classes="small-btn") - with gr.Column() as c: - output = gr.Video(label="Generated Music") - wave_file = gr.File(label=".wav file", elem_id="output_wavefile", interactive=True) - seed_used = gr.Number(label='Seed used', value=-1, interactive=False) - - radio.change(toggle_audio_src, radio, [melody_filepath], queue=False, show_progress=False) - melody_filepath.change(load_melody_filepath, inputs=[melody_filepath, title], outputs=[title, prompt_index , model], api_name="melody_filepath_change", queue=False) - reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False, api_name="reuse_seed") - submit.click(predict, inputs=[model, text,melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap, prompt_index, include_title, include_settings, harmony_only], outputs=[output, wave_file, seed_used], api_name="submit") - gr.Examples( - fn=predict, - examples=[ - [ - "4/4 120bpm 320kbps 48khz, An 80s driving pop song with heavy drums and synth pads in the background", - "./assets/bach.mp3", - "melody", - "80s Pop Synth" + with gr.Row(): + text = gr.Text(label="Describe your music", interactive=True, value="4/4 100bpm 320kbps 48khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi") + with gr.Column(): + duration = gr.Slider(minimum=1, maximum=720, value=10, label="Duration (s)", interactive=True) + model = gr.Radio(["melody", "medium", "small", "large", "melody-large", "stereo-small", "stereo-medium", "stereo-large", "stereo-melody", "stereo-melody-large"], label="AI Model", value="melody", interactive=True) + with gr.Row(): + submit = gr.Button("Generate", elem_id="btn-generate") + # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license. 
+ _ = gr.Button("Interrupt", elem_id="btn-interrupt").click(fn=interrupt, queue=False) + with gr.Row(): + with gr.Column(): + radio = gr.Radio(["file", "mic"], value="file", label="Condition on a melody (optional) File or Mic") + melody_filepath = gr.Audio(sources=["upload"], type="filepath", label="Melody Condition (optional)", interactive=True, elem_id="melody-input") + with gr.Column(): + harmony_only = gr.Radio(label="Use Harmony Only",choices=["No", "Yes"], value="No", interactive=True, info="Remove Drums?") + prompt_index = gr.Slider(label="Melody Condition Sample Segment", minimum=-1, maximum=MAX_PROMPT_INDEX, step=1, value=0, interactive=True, info="Which 30 second segment to condition with, - 1 condition each segment independantly") + with gr.Accordion("Video", open=False): + with gr.Row(): + background= gr.Image(value="./assets/background.png", sources=["upload"], label="Background", width=768, height=512, type="filepath", interactive=True) + with gr.Column(): + include_title = gr.Checkbox(label="Add Title", value=True, interactive=True) + include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True) + with gr.Row(): + title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True) + settings_font = gr.Text(label="Settings Font", value="./assets/arial.ttf", interactive=True) + settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#c87f05", interactive=True) + with gr.Accordion("Expert", open=False): + with gr.Row(): + overlap = gr.Slider(minimum=0, maximum=15, value=2, step=1, label="Verse Overlap", interactive=True) + dimension = gr.Slider(minimum=-2, maximum=2, value=2, step=1, label="Dimension", info="determines which direction to add new segements of audio. (1 = stack tracks, 2 = lengthen, -2..0 = ?)", interactive=True) + with gr.Row(): + topk = gr.Number(label="Top-k", value=280, precision=0, interactive=True) + topp = gr.Number(label="Top-p", value=1150, precision=0, interactive=True) + temperature = gr.Number(label="Randomness Temperature", value=0.7, precision=None, interactive=True) + cfg_coef = gr.Number(label="Classifier Free Guidance", value=8.5, precision=None, interactive=True) + with gr.Row(): + seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True) + gr.Button('\U0001f3b2\ufe0f', elem_classes="small-btn").click(fn=lambda: -1, outputs=[seed], queue=False) + reuse_seed = gr.Button('\u267b\ufe0f', elem_classes="small-btn") + with gr.Column() as c: + output = gr.Video(label="Generated Music") + wave_file = gr.File(label=".wav file", elem_id="output_wavefile", interactive=True) + seed_used = gr.Number(label='Seed used', value=-1, interactive=False) + + radio.change(toggle_audio_src, radio, [melody_filepath], queue=False, show_progress=False) + melody_filepath.change(load_melody_filepath, inputs=[melody_filepath, title], outputs=[title, prompt_index , model], api_name="melody_filepath_change", queue=False) + reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False, api_name="reuse_seed") + submit.click(predict, inputs=[model, text,melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap, prompt_index, include_title, include_settings, harmony_only], outputs=[output, wave_file, seed_used], api_name="submit") + gr.Examples( + fn=predict, + examples=[ + [ + "4/4 120bpm 320kbps 48khz, An 80s driving pop song with heavy drums and synth pads in the background", + "./assets/bach.mp3", + 
"stereo-melody-large", + "80s Pop Synth" + ], + [ + "4/4 120bpm 320kbps 48khz, A cheerful country song with acoustic guitars", + "./assets/bolero_ravel.mp3", + "melody", + "Country Guitar" + ], + [ + "4/4 120bpm 320kbps 48khz, 90s rock song with electric guitar and heavy drums", + None, + "stereo-medium", + "90s Rock Guitar" + ], + [ + "4/4 120bpm 320kbps 48khz, a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions", + "./assets/bach.mp3", + "melody-large", + "EDM my Bach" + ], + [ + "4/4 320kbps 48khz, lofi slow bpm electro chill with organic samples", + None, + "medium", + "LoFi Chill" + ], ], - [ - "4/4 120bpm 320kbps 48khz, A cheerful country song with acoustic guitars", - "./assets/bolero_ravel.mp3", - "melody", - "Country Guitar" - ], - [ - "4/4 120bpm 320kbps 48khz, 90s rock song with electric guitar and heavy drums", - None, - "medium", - "90s Rock Guitar" - ], - [ - "4/4 120bpm 320kbps 48khz, a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions", - "./assets/bach.mp3", - "melody", - "EDM my Bach" - ], - [ - "4/4 320kbps 48khz, lofi slow bpm electro chill with organic samples", - None, - "medium", - "LoFi Chill" - ], - ], - inputs=[text, melody_filepath, model, title], - outputs=[output] - ) - gr.HTML(value=versions_html(), visible=True, elem_id="versions") - + inputs=[text, melody_filepath, model, title], + outputs=[output] + ) + gr.HTML(value=versions_html(), visible=True, elem_id="versions") + with gr.Tab("User History") as history_tab: + modules.user_history.render() + # Show the interface launch_kwargs = {} share = kwargs.get('share', False) @@ -471,10 +539,10 @@ def ui(**kwargs): if share: launch_kwargs['share'] = share launch_kwargs['favicon_path']= "./assets/favicon.ico" + - - demo.queue(max_size=10, concurrency_count=1, api_open=False).launch(**launch_kwargs) + interface.queue(max_size=10, api_open=False).launch(**launch_kwargs) if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -518,7 +586,7 @@ if __name__ == "__main__": args = parser.parse_args() launch_kwargs = {} - launch_kwargs['server_name'] = args.listen + launch_kwargs['listen'] = args.listen if args.username and args.password: launch_kwargs['auth'] = (args.username, args.password) @@ -528,7 +596,7 @@ if __name__ == "__main__": launch_kwargs['inbrowser'] = args.inbrowser if args.share: launch_kwargs['share'] = args.share - launch_kwargs['favicon_path']= "./assets/favicon.ico" + launch_kwargs['favicon_path']= "./assets/favicon.ico" UNLOAD_MODEL = args.unload_model @@ -538,6 +606,6 @@ if __name__ == "__main__": ui( unload_to_cpu = MOVE_TO_CPU, - share=args.share - + share=args.share, + **launch_kwargs, ) diff --git a/assets/KuritaSurnLogox64.png b/assets/KuritaSurnLogox64.png new file mode 100644 index 0000000000000000000000000000000000000000..41529d29338aae8a90825a587c0dba074ca67d0f Binary files /dev/null and b/assets/KuritaSurnLogox64.png differ diff --git a/assets/Vermilion-Musical-Notes-Typography-No-Background.svg b/assets/Vermilion-Musical-Notes-Typography-No-Background.svg new file mode 100644 index 0000000000000000000000000000000000000000..6d9d192b0e6fc881e537184f6b883db57534a3c7 --- /dev/null +++ b/assets/Vermilion-Musical-Notes-Typography-No-Background.svg @@ -0,0 +1,5158 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
[SVG markup content omitted]
\ No newline at end of file
diff --git a/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 b/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..71be35a12d3e97993996806d6a94175568b2761f
Binary files /dev/null and b/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 differ
diff --git a/assets/icon_delete.png b/assets/icon_delete.png
new file mode 100644
index 0000000000000000000000000000000000000000..9b1d64da9c7e919b181a9a55a20cb5764bf5c513
Binary files /dev/null and b/assets/icon_delete.png differ
diff --git a/assets/icon_download.png b/assets/icon_download.png new file mode 100644 index 0000000000000000000000000000000000000000..2e2a6e55694c6347d339873cf9213f248e023489 Binary files /dev/null and b/assets/icon_download.png differ diff --git a/assets/icon_refresh.png b/assets/icon_refresh.png new file mode 100644 index 0000000000000000000000000000000000000000..ca3f53f4aa809a9add10dfe883eb7fe550089cdf Binary files /dev/null and b/assets/icon_refresh.png differ diff --git a/assets/logo_animation_256.gif b/assets/logo_animation_256.gif new file mode 100644 index 0000000000000000000000000000000000000000..71510e9ba405e1b8a720bc13f2ee6718627d6996 --- /dev/null +++ b/assets/logo_animation_256.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84aa8c95f88f4c9d110dc87c344ec92786e8b5c464ac8141a9c3b12bedf2ed71 +size 140317 diff --git a/assets/screenshot.png b/assets/screenshot.png new file mode 100644 index 0000000000000000000000000000000000000000..ee2c407f025853073946bfc4b31dc27e7da3f0ed --- /dev/null +++ b/assets/screenshot.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89abfaffefc18124ffe8f0775eb4dcdbc589bb97befc76dd3f3fc48992991e2e +size 388240 diff --git a/assets/sirens_and_a_humming_engine_approach_and_pass.mp3 b/assets/sirens_and_a_humming_engine_approach_and_pass.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..e74b5b61a624fbf69f5e70febc64c91658bb38ac Binary files /dev/null and b/assets/sirens_and_a_humming_engine_approach_and_pass.mp3 differ diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py index 6b8594f470200ff5c000542ef115375ed69b749c..4f5fc0574d9ec5a2b8cff39dd8ecd5f2a81f441b 100644 --- a/audiocraft/__init__.py +++ b/audiocraft/__init__.py @@ -7,4 +7,4 @@ # flake8: noqa from . import data, modules, models -__version__ = '0.0.2a2' +__version__ = '1.4.Surn' diff --git a/audiocraft/data/__init__.py b/audiocraft/data/__init__.py index 708a3dcead8dda89374a021177481dacae9f7fe9..a0f2b08a0a50701ca6d86ff1287300f1f94ffa23 100644 --- a/audiocraft/data/__init__.py +++ b/audiocraft/data/__init__.py @@ -5,4 +5,4 @@ # LICENSE file in the root directory of this source tree. # flake8: noqa -from . import audio, audio_dataset +from . import audio, audio_dataset, info_audio_dataset diff --git a/audiocraft/data/audio.py b/audiocraft/data/audio.py index 05fa53ae8ad1b40ab8b9c5dd134227a2a58c55fe..bb98c98e4e81664f6d1cda614df88fc333b0b846 100644 --- a/audiocraft/data/audio.py +++ b/audiocraft/data/audio.py @@ -21,6 +21,7 @@ from torch.nn import functional as F import torchaudio as ta import av +import subprocess as sp from .audio_utils import f32_pcm, i16_pcm, normalize_audio, convert_audio @@ -149,7 +150,17 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., wav = F.pad(wav, (0, expected_frames - wav.shape[-1])) return wav, sr - +def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]): + # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely. 
+ assert wav.dim() == 2, wav.shape + command = [ + 'ffmpeg', + '-loglevel', 'error', + '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]), + '-i', '-'] + flags + [str(out_path)] + input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes() + sp.run(command, input=input_, check=True) + def audio_write(stem_name: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, format: str = 'wav', mp3_rate: int = 320, normalize: bool = True, @@ -215,3 +226,77 @@ def audio_write(stem_name: tp.Union[str, Path], path.unlink() raise return path + +def audio_write2(stem_name: tp.Union[str, Path], + wav: torch.Tensor, sample_rate: int, + format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None, + normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1, + rms_headroom_db: float = 18, loudness_headroom_db: float = 14, + loudness_compressor: bool = False, + log_clipping: bool = True, make_parent_dir: bool = True, + add_suffix: bool = True) -> Path: + """Convenience function for saving audio to disk. Returns the filename the audio was written to. + + Args: + stem_name (str or Path): Filename without extension which will be added automatically. + wav (torch.Tensor): Audio data to save. + sample_rate (int): Sample rate of audio data. + format (str): Either "wav", "mp3", "ogg", or "flac". + mp3_rate (int): kbps when using mp3s. + ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself. + normalize (bool): if `True` (default), normalizes according to the prescribed + strategy (see after). If `False`, the strategy is only used in case clipping + would happen. + strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', + i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square + with extra headroom to avoid clipping. 'clip' just clips. + peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. + rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger + than the `peak_clip` one to avoid further clipping. + loudness_headroom_db (float): Target loudness for loudness normalization. + loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'. + when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still + occurs despite strategy (only for 'rms'). + make_parent_dir (bool): Make parent directory if it doesn't exist. + Returns: + Path: Path of the saved audio. + """ + assert wav.dtype.is_floating_point, "wav is not floating point" + if wav.dim() == 1: + wav = wav[None] + elif wav.dim() > 2: + raise ValueError("Input wav should be at most 2 dimension.") + assert wav.isfinite().all() + wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db, + rms_headroom_db, loudness_headroom_db, loudness_compressor, + log_clipping=log_clipping, sample_rate=sample_rate, + stem_name=str(stem_name)) + if format == 'mp3': + suffix = '.mp3' + flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k'] + elif format == 'wav': + suffix = '.wav' + flags = ['-f', 'wav', '-c:a', 'pcm_s16le'] + elif format == 'ogg': + suffix = '.ogg' + flags = ['-f', 'ogg', '-c:a', 'libvorbis'] + if ogg_rate is not None: + flags += ['-b:a', f'{ogg_rate}k'] + elif format == 'flac': + suffix = '.flac' + flags = ['-f', 'flac'] + else: + raise RuntimeError(f"Invalid format {format}. 
Only wav or mp3 are supported.") + if not add_suffix: + suffix = '' + path = Path(str(stem_name) + suffix) + if make_parent_dir: + path.parent.mkdir(exist_ok=True, parents=True) + try: + _piping_to_ffmpeg(path, wav, sample_rate, flags) + except Exception: + if path.exists(): + # we do not want to leave half written files around. + path.unlink() + raise + return path \ No newline at end of file diff --git a/audiocraft/data/audio_dataset.py b/audiocraft/data/audio_dataset.py index cf21422ea0059cb2d6553f93e608b8f9fa0d3a50..9d7442526186b3712f5d4754f928a40ecd964174 100644 --- a/audiocraft/data/audio_dataset.py +++ b/audiocraft/data/audio_dataset.py @@ -3,12 +3,16 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - +"""AudioDataset support. In order to handle a larger number of files +without having to scan again the folders, we precompute some metadata +(filename, sample rate, duration), and use that to efficiently sample audio segments. +""" import argparse import copy from concurrent.futures import ThreadPoolExecutor, Future from dataclasses import dataclass, fields from contextlib import ExitStack +from functools import lru_cache import gzip import json import logging @@ -81,9 +85,12 @@ class AudioMeta(BaseInfo): class SegmentInfo(BaseInfo): meta: AudioMeta seek_time: float - n_frames: int # actual number of frames without padding + # The following values are given once the audio is processed, e.g. + # at the target sample rate and target number of channels. + n_frames: int # actual number of frames without padding total_frames: int # total number of frames, padding included - sample_rate: int # actual sample rate + sample_rate: int # actual sample rate + channels: int # number of audio channels. DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a'] @@ -114,8 +121,8 @@ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta: Args: m (AudioMeta): Audio meta to resolve. - fast (bool): If True, uses a really fast check for determining if a file is already absolute or not. - Only valid on Linux/Mac. + fast (bool): If True, uses a really fast check for determining if a file + is already absolute or not. Only valid on Linux/Mac. Returns: AudioMeta: Audio meta with resolved path. """ @@ -151,7 +158,7 @@ def find_audio_files(path: tp.Union[Path, str], progress (bool): Whether to log progress on audio files collection. workers (int): number of parallel workers, if 0, use only the current thread. Returns: - List[AudioMeta]: List of audio file path and its metadata. + list of AudioMeta: List of audio file path and its metadata. """ audio_files = [] futures: tp.List[Future] = [] @@ -203,7 +210,7 @@ def load_audio_meta(path: tp.Union[str, Path], resolve (bool): Whether to resolve the path from AudioMeta (default=True). fast (bool): activates some tricks to make things faster. Returns: - List[AudioMeta]: List of audio file path and its total duration. + list of AudioMeta: List of audio file path and its total duration. """ open_fn = gzip.open if str(path).lower().endswith('.gz') else open with open_fn(path, 'rb') as fp: # type: ignore @@ -250,9 +257,14 @@ class AudioDataset: allows to return a tuple containing the torch Tensor and additional metadata on the segment and the original audio meta. + Note that you can call `start_epoch(epoch)` in order to get + a deterministic "randomization" for `shuffle=True`. + For a given epoch and dataset index, this will always return the same extract. 
+ You can get back some diversity by setting the `shuffle_seed` param. + Args: - meta (tp.List[AudioMeta]): List of audio files metadata. - segment_duration (float): Optional segment duration of audio to load. + meta (list of AudioMeta): List of audio files metadata. + segment_duration (float, optional): Optional segment duration of audio to load. If not specified, the dataset will load the full audio segment from the file. shuffle (bool): Set to `True` to have the data reshuffled at every epoch. sample_rate (int): Target sample rate of the loaded audio samples. @@ -266,10 +278,19 @@ class AudioDataset: is shorter than the desired segment. max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset. return_info (bool): Whether to return the wav only or return wav along with segment info and metadata. - min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided + min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided audio shorter than this will be filtered out. - max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided + max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided audio longer than this will be filtered out. + shuffle_seed (int): can be used to further randomize + load_wav (bool): if False, skip loading the wav but returns a tensor of 0 + with the expected segment_duration (which must be provided if load_wav is False). + permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration` + are False. Will ensure a permutation on files when going through the dataset. + In that case the epoch number must be provided in order for the model + to continue the permutation across epochs. In that case, it is assumed + that `num_samples = total_batch_size * num_updates_per_epoch`, with + `total_batch_size` the overall batch size accounting for all gpus. """ def __init__(self, meta: tp.List[AudioMeta], @@ -285,16 +306,14 @@ class AudioDataset: max_read_retry: int = 10, return_info: bool = False, min_audio_duration: tp.Optional[float] = None, - max_audio_duration: tp.Optional[float] = None + max_audio_duration: tp.Optional[float] = None, + shuffle_seed: int = 0, + load_wav: bool = True, + permutation_on_files: bool = False, ): - assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.' + assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta." 
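+        # A segment_duration of None means whole files are returned as-is; otherwise
+        # random windows of that duration are sampled in __getitem__.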
assert segment_duration is None or segment_duration > 0 assert segment_duration is None or min_segment_ratio >= 0 - logging.debug(f'sample_on_duration: {sample_on_duration}') - logging.debug(f'sample_on_weight: {sample_on_weight}') - logging.debug(f'pad: {pad}') - logging.debug(f'min_segment_ratio: {min_segment_ratio}') - self.segment_duration = segment_duration self.min_segment_ratio = min_segment_ratio self.max_audio_duration = max_audio_duration @@ -317,13 +336,25 @@ class AudioDataset: self.sampling_probabilities = self._get_sampling_probabilities() self.max_read_retry = max_read_retry self.return_info = return_info + self.shuffle_seed = shuffle_seed + self.current_epoch: tp.Optional[int] = None + self.load_wav = load_wav + if not load_wav: + assert segment_duration is not None + self.permutation_on_files = permutation_on_files + if permutation_on_files: + assert not self.sample_on_duration + assert not self.sample_on_weight + assert self.shuffle + + def start_epoch(self, epoch: int): + self.current_epoch = epoch def __len__(self): return self.num_samples def _get_sampling_probabilities(self, normalized: bool = True): - """Return the sampling probabilities for each file inside `self.meta`. - """ + """Return the sampling probabilities for each file inside `self.meta`.""" scores: tp.List[float] = [] for file_meta in self.meta: score = 1. @@ -337,12 +368,32 @@ class AudioDataset: probabilities /= probabilities.sum() return probabilities - def sample_file(self, rng: torch.Generator) -> AudioMeta: - """Sample a given file from `self.meta`. Can be overriden in subclasses. + @staticmethod + @lru_cache(16) + def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int): + # Used to keep the most recent files permutation in memory implicitely. + # will work unless someone is using a lot of Datasets in parallel. + rng = torch.Generator() + rng.manual_seed(base_seed + permutation_index) + return torch.randperm(num_files, generator=rng) + + def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta: + """Sample a given file from `self.meta`. Can be overridden in subclasses. This is only called if `segment_duration` is not None. You must use the provided random number generator `rng` for reproducibility. + You can further make use of the index accessed. """ + if self.permutation_on_files: + assert self.current_epoch is not None + total_index = self.current_epoch * len(self) + index + permutation_index = total_index // len(self.meta) + relative_index = total_index % len(self.meta) + permutation = AudioDataset._get_file_permutation( + len(self.meta), permutation_index, self.shuffle_seed) + file_index = permutation[relative_index] + return self.meta[file_index] + if not self.sample_on_weight and not self.sample_on_duration: file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item()) else: @@ -350,6 +401,15 @@ class AudioDataset: return self.meta[file_index] + def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1): + # Override this method in subclass if needed. 
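+        # When load_wav is False we skip the disk read entirely and hand back
+        # `segment_duration` seconds of silence with the expected shape, which is
+        # sufficient when training from cached audio tokens rather than waveforms.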
+ if self.load_wav: + return audio_read(path, seek_time, duration, pad=False) + else: + assert self.segment_duration is not None + n_frames = int(self.sample_rate * self.segment_duration) + return torch.zeros(self.channels, n_frames), self.sample_rate + def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]: if self.segment_duration is None: file_meta = self.meta[index] @@ -357,18 +417,22 @@ class AudioDataset: out = convert_audio(out, sr, self.sample_rate, self.channels) n_frames = out.shape[-1] segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames, - sample_rate=self.sample_rate) + sample_rate=self.sample_rate, channels=out.shape[0]) else: rng = torch.Generator() if self.shuffle: - # We use index, plus extra randomness - rng.manual_seed(index + self.num_samples * random.randint(0, 2**24)) + # We use index, plus extra randomness, either totally random if we don't know the epoch. + # otherwise we make use of the epoch number and optional shuffle_seed. + if self.current_epoch is None: + rng.manual_seed(index + self.num_samples * random.randint(0, 2**24)) + else: + rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed)) else: # We only use index rng.manual_seed(index) for retry in range(self.max_read_retry): - file_meta = self.sample_file(rng) + file_meta = self.sample_file(index, rng) # We add some variance in the file position even if audio file is smaller than segment # without ending up with empty segments max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio) @@ -381,7 +445,7 @@ class AudioDataset: if self.pad: out = F.pad(out, (0, target_frames - n_frames)) segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames, - sample_rate=self.sample_rate) + sample_rate=self.sample_rate, channels=out.shape[0]) except Exception as exc: logger.warning("Error opening file %s: %r", file_meta.path, exc) if retry == self.max_read_retry - 1: @@ -423,7 +487,7 @@ class AudioDataset: if to_pad: # Each wav could be of a different duration as they are not segmented. for i in range(len(samples)): - # Determines the total legth of the signal with padding, so we update here as we pad. + # Determines the total length of the signal with padding, so we update here as we pad. segment_infos[i].total_frames = max_len wavs[i] = _pad_wav(wavs[i]) @@ -436,9 +500,7 @@ class AudioDataset: return torch.stack(samples) def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: - """Filters out audio files with short durations. - Removes from meta files that have durations that will not allow to samples examples from them. - """ + """Filters out audio files with audio durations that will not allow to sample examples from them.""" orig_len = len(meta) # Filter data that is too short. diff --git a/audiocraft/data/audio_utils.py b/audiocraft/data/audio_utils.py index 7595435329587d7fe97afbff5f74664a808ea050..11c0c3c043907c4dbbac09e29af805da20fd1a01 100644 --- a/audiocraft/data/audio_utils.py +++ b/audiocraft/data/audio_utils.py @@ -3,7 +3,8 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
- +"""Various utilities for audio convertion (pcm format, sample rate and channels), +and volume normalization.""" import sys import typing as tp @@ -150,17 +151,19 @@ def f32_pcm(wav: torch.Tensor) -> torch.Tensor: """ if wav.dtype.is_floating_point: return wav - else: - assert wav.dtype == torch.int16 + elif wav.dtype == torch.int16: return wav.float() / 2**15 + elif wav.dtype == torch.int32: + return wav.float() / 2**31 + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") def i16_pcm(wav: torch.Tensor) -> torch.Tensor: """Convert audio to int 16 bits PCM format. - ..Warning:: There exist many formula for doing this convertion. None are perfect - due to the asymetry of the int16 range. One either have possible clipping, DC offset, - or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom, + ..Warning:: There exist many formula for doing this conversion. None are perfect + due to the asymmetry of the int16 range. One either have possible clipping, DC offset, + or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom, it is possible that `i16_pcm(f32_pcm)) != Identity`. """ if wav.dtype.is_floating_point: diff --git a/audiocraft/data/info_audio_dataset.py b/audiocraft/data/info_audio_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..47ab4b1594faf1e9f1ce962fb980d80295b1f079 --- /dev/null +++ b/audiocraft/data/info_audio_dataset.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Base classes for the datasets that also provide non-audio metadata, +e.g. description, text transcription etc. +""" +from dataclasses import dataclass +import logging +import math +import re +import typing as tp + +import torch + +from .audio_dataset import AudioDataset, AudioMeta +from ..environment import AudioCraftEnvironment +from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes + + +logger = logging.getLogger(__name__) + + +def _clusterify_meta(meta: AudioMeta) -> AudioMeta: + """Monkey-patch meta to match cluster specificities.""" + meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path) + if meta.info_path is not None: + meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path) + return meta + + +def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: + """Monkey-patch all meta to match cluster specificities.""" + return [_clusterify_meta(m) for m in meta] + + +@dataclass +class AudioInfo(SegmentWithAttributes): + """Dummy SegmentInfo with empty attributes. + + The InfoAudioDataset is expected to return metadata that inherits + from SegmentWithAttributes class and can return conditioning attributes. + + This basically guarantees all datasets will be compatible with current + solver that contain conditioners requiring this. + """ + audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM. + + def to_condition_attributes(self) -> ConditioningAttributes: + return ConditioningAttributes() + + +class InfoAudioDataset(AudioDataset): + """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform. + + See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments. 
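+
+    Minimal usage sketch (the manifest path and keyword values are illustrative only):
+
+        from audiocraft.data.audio_dataset import load_audio_meta
+
+        meta = load_audio_meta('egs/example/data.jsonl.gz')
+        dataset = InfoAudioDataset(meta, segment_duration=10., sample_rate=32000,
+                                   channels=1, return_info=True)
+        wav, info = dataset[0]  # `info` is an AudioInfo (a SegmentWithAttributes)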
+ """ + def __init__(self, meta: tp.List[AudioMeta], **kwargs): + super().__init__(clusterify_all_meta(meta), **kwargs) + + def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]: + if not self.return_info: + wav = super().__getitem__(index) + assert isinstance(wav, torch.Tensor) + return wav + wav, meta = super().__getitem__(index) + return wav, AudioInfo(**meta.to_dict()) + + +def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]: + """Preprocess a single keyword or possible a list of keywords.""" + if isinstance(value, list): + return get_keyword_list(value) + else: + return get_keyword(value) + + +def get_string(value: tp.Optional[str]) -> tp.Optional[str]: + """Preprocess a single keyword.""" + if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None': + return None + else: + return value.strip() + + +def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]: + """Preprocess a single keyword.""" + if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None': + return None + else: + return value.strip().lower() + + +def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]: + """Preprocess a list of keywords.""" + if isinstance(values, str): + values = [v.strip() for v in re.split(r'[,\s]', values)] + elif isinstance(values, float) and math.isnan(values): + values = [] + if not isinstance(values, list): + logger.debug(f"Unexpected keyword list {values}") + values = [str(values)] + + kws = [get_keyword(v) for v in values] + kw_list = [k for k in kws if k is not None] + if len(kw_list) == 0: + return None + else: + return kw_list diff --git a/audiocraft/data/zip.py b/audiocraft/data/zip.py index 1f1154231da321dd38d151ff285dbcff5e38a6e0..f0b17849d36991e7def35a14d3d518b9d867ce36 100644 --- a/audiocraft/data/zip.py +++ b/audiocraft/data/zip.py @@ -3,6 +3,8 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +"""Utility for reading some info from inside a zip file. +""" import typing import zipfile @@ -18,13 +20,13 @@ MODE = Literal['r', 'w', 'x', 'a'] @dataclass(order=True) class PathInZip: - """Class for holding a path of file within a zip file. + """Hold a path of file within a zip file. Args: - path: The convention is : + path (str): The convention is :. Let's assume there is a zip file /some/location/foo.zip and inside of it is a json file located at /data/file1.json, - Then we expect path = "/some/location/foo.zip:/data/file1.json" + Then we expect path = "/some/location/foo.zip:/data/file1.json". """ INFO_PATH_SEP = ':' @@ -55,7 +57,7 @@ def set_zip_cache_size(max_size: int): """Sets the maximal LRU caching for zip file opening. Args: - max_size: the maximal LRU cache. + max_size (int): the maximal LRU cache. """ global _cached_open_zip _cached_open_zip = lru_cache(max_size)(_open_zip) @@ -65,8 +67,8 @@ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO: """Opens a file stored inside a zip and returns a file-like object. Args: - path_in_zip: A PathInZip object representing the file to return a file-like object of. - mode: The mode in which to open the file with. + path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of. + mode (str): The mode in which to open the file with. Returns: A file-like object for PathInZip. 
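+
+    Illustrative call, reusing the path convention documented on PathInZip above
+    (assumes the target member is a JSON file):
+
+        fobj = open_file_in_zip(PathInZip('/some/location/foo.zip:/data/file1.json'))
+        metadata = json.load(fobj)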
""" diff --git a/audiocraft/environment.py b/audiocraft/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..adc7819305758bb50a9984928bfa7f13eabef5f5 --- /dev/null +++ b/audiocraft/environment.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Provides cluster and tools configuration across clusters (slurm, dora, utilities). +""" + +import logging +import os +from pathlib import Path +import re +import typing as tp + +import omegaconf + +from .utils.cluster import _guess_cluster_type + + +logger = logging.getLogger(__name__) + + +class AudioCraftEnvironment: + """Environment configuration for teams and clusters. + + AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment + or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment + provides pointers to a reference folder resolved automatically across clusters that is shared across team members, + allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically + map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters. + + The cluster type is identified automatically and base configuration file is read from config/teams.yaml. + Use the following environment variables to specify the cluster, team or configuration: + + AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type + cannot be inferred automatically. + AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration. + If not set, configuration is read from config/teams.yaml. + AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team. + Cluster configuration are shared across teams to match compute allocation, + specify your cluster configuration in the configuration file under a key mapping + your team name. + """ + _instance = None + DEFAULT_TEAM = "default" + + def __init__(self) -> None: + """Loads configuration.""" + self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM) + cluster_type = _guess_cluster_type() + cluster = os.getenv( + "AUDIOCRAFT_CLUSTER", cluster_type.value + ) + logger.info("Detecting cluster type %s", cluster_type) + + self.cluster: str = cluster + + config_path = os.getenv( + "AUDIOCRAFT_CONFIG", + Path(__file__) + .parent.parent.joinpath("config/teams", self.team) + .with_suffix(".yaml"), + ) + self.config = omegaconf.OmegaConf.load(config_path) + self._dataset_mappers = [] + cluster_config = self._get_cluster_config() + if "dataset_mappers" in cluster_config: + for pattern, repl in cluster_config["dataset_mappers"].items(): + regex = re.compile(pattern) + self._dataset_mappers.append((regex, repl)) + + def _get_cluster_config(self) -> omegaconf.DictConfig: + assert isinstance(self.config, omegaconf.DictConfig) + return self.config[self.cluster] + + @classmethod + def instance(cls): + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset(cls): + """Clears the environment and forces a reload on next invocation.""" + cls._instance = None + + @classmethod + def get_team(cls) -> str: + """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var. + If not defined, defaults to "labs". 
+ """ + return cls.instance().team + + @classmethod + def get_cluster(cls) -> str: + """Gets the detected cluster. + This value can be overridden by the AUDIOCRAFT_CLUSTER env var. + """ + return cls.instance().cluster + + @classmethod + def get_dora_dir(cls) -> Path: + """Gets the path to the dora directory for the current team and cluster. + Value is overridden by the AUDIOCRAFT_DORA_DIR env var. + """ + cluster_config = cls.instance()._get_cluster_config() + dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"]) + logger.warning(f"Dora directory: {dora_dir}") + return Path(dora_dir) + + @classmethod + def get_reference_dir(cls) -> Path: + """Gets the path to the reference directory for the current team and cluster. + Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var. + """ + cluster_config = cls.instance()._get_cluster_config() + return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"])) + + @classmethod + def get_slurm_exclude(cls) -> tp.Optional[str]: + """Get the list of nodes to exclude for that cluster.""" + cluster_config = cls.instance()._get_cluster_config() + return cluster_config.get("slurm_exclude") + + @classmethod + def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str: + """Gets the requested partitions for the current team and cluster as a comma-separated string. + + Args: + partition_types (list[str], optional): partition types to retrieve. Values must be + from ['global', 'team']. If not provided, the global partition is returned. + """ + if not partition_types: + partition_types = ["global"] + + cluster_config = cls.instance()._get_cluster_config() + partitions = [ + cluster_config["partitions"][partition_type] + for partition_type in partition_types + ] + return ",".join(partitions) + + @classmethod + def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path: + """Converts reference placeholder in path with configured reference dir to resolve paths. + + Args: + path (str or Path): Path to resolve. + Returns: + Path: Resolved path. + """ + path = str(path) + + if path.startswith("//reference"): + reference_dir = cls.get_reference_dir() + logger.warn(f"Reference directory: {reference_dir}") + assert ( + reference_dir.exists() and reference_dir.is_dir() + ), f"Reference directory does not exist: {reference_dir}." + path = re.sub("^//reference", str(reference_dir), path) + + return Path(path) + + @classmethod + def apply_dataset_mappers(cls, path: str) -> str: + """Applies dataset mapping regex rules as defined in the configuration. + If no rules are defined, the path is returned as-is. + """ + instance = cls.instance() + + for pattern, repl in instance._dataset_mappers: + path = pattern.sub(repl, path) + + return path diff --git a/audiocraft/models/__init__.py b/audiocraft/models/__init__.py index 92c7a48a200eba455044cd66e0d2c1efe6494f5c..ceefaa1599de2d610c6dd8a898888ba52674b56b 100644 --- a/audiocraft/models/__init__.py +++ b/audiocraft/models/__init__.py @@ -4,7 +4,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +""" +Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel. +""" # flake8: noqa +from . 
import builders, loaders +from .encodec import ( + CompressionModel, EncodecModel, DAC, + HFEncodecModel, HFEncodecCompressionModel) from .musicgen import MusicGen from .lm import LMModel from .encodec import CompressionModel, EncodecModel diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py index 77ee5f96fea2e3c9e475fe961bc1a5ee473ed8eb..b7144874457e569d6e25fe30cafa0cddc1dd59a1 100644 --- a/audiocraft/models/builders.py +++ b/audiocraft/models/builders.py @@ -10,32 +10,34 @@ from the Hydra config. """ import typing as tp -import warnings import audiocraft import omegaconf import torch -from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel # noqa +from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel from .lm import LMModel from ..modules.codebooks_patterns import ( CodebooksPatternProvider, DelayedPatternProvider, + MusicLMPattern, ParallelPatternProvider, UnrolledPatternProvider, - VALLEPattern, - MusicLMPattern, + CoarseFirstPattern, ) from ..modules.conditioners import ( BaseConditioner, + ChromaStemConditioner, + CLAPEmbeddingConditioner, + ConditionFuser, ConditioningProvider, LUTConditioner, T5Conditioner, - ConditionFuser, - ChromaStemConditioner, ) +from .unet import DiffusionUnet from .. import quantization as qt from ..utils.utils import dict_from_config +from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer: @@ -60,12 +62,11 @@ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs) return encoder, decoder else: - raise KeyError(f'Unexpected compression model {cfg.compression_model}') + raise KeyError(f"Unexpected compression model {cfg.compression_model}") def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: - """Instantiate a compression model. - """ + """Instantiate a compression model.""" if cfg.compression_model == 'encodec': kwargs = dict_from_config(getattr(cfg, 'encodec')) encoder_name = kwargs.pop('autoencoder') @@ -73,20 +74,17 @@ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) frame_rate = kwargs['sample_rate'] // encoder.hop_length - renormalize = kwargs.pop('renormalize', None) - renorm = kwargs.pop('renorm') - if renormalize is None: - renormalize = renorm is not None - warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.") + renormalize = kwargs.pop('renormalize', False) + # deprecated params + kwargs.pop('renorm', None) return EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) else: - raise KeyError(f'Unexpected compression model {cfg.compression_model}') + raise KeyError(f"Unexpected compression model {cfg.compression_model}") def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: - """Instantiate a transformer LM. 
- """ + """Instantiate a transformer LM.""" if cfg.lm_model == 'transformer_lm': kwargs = dict_from_config(getattr(cfg, 'transformer_lm')) n_q = kwargs['n_q'] @@ -94,14 +92,14 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern') attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout')) cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance')) - cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"] + cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef'] fuser = get_condition_fuser(cfg) condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device) - if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programatically + if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically kwargs['cross_attention'] = True if codebooks_pattern_cfg.modeling is None: assert q_modeling is not None, \ - 'LM model should either have a codebook pattern defined or transformer_lm.q_modeling' + "LM model should either have a codebook pattern defined or transformer_lm.q_modeling" codebooks_pattern_cfg = omegaconf.OmegaConf.create( {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}} ) @@ -118,45 +116,50 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: **kwargs ).to(cfg.device) else: - raise KeyError(f'Unexpected LM model {cfg.lm_model}') + raise KeyError(f"Unexpected LM model {cfg.lm_model}") def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider: - """Instantiate a conditioning model. - """ + """Instantiate a conditioning model.""" device = cfg.device duration = cfg.dataset.segment_duration - cfg = getattr(cfg, "conditioners") - cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg + cfg = getattr(cfg, 'conditioners') + dict_cfg = {} if cfg is None else dict_from_config(cfg) conditioners: tp.Dict[str, BaseConditioner] = {} - with omegaconf.open_dict(cfg): - condition_provider_args = cfg.pop('args', {}) - for cond, cond_cfg in cfg.items(): - model_type = cond_cfg["model"] + condition_provider_args = dict_cfg.pop('args', {}) + condition_provider_args.pop('merge_text_conditions_p', None) + condition_provider_args.pop('drop_desc_p', None) + + for cond, cond_cfg in dict_cfg.items(): + model_type = cond_cfg['model'] model_args = cond_cfg[model_type] - if model_type == "t5": + if model_type == 't5': conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args) - elif model_type == "lut": + elif model_type == 'lut': conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args) - elif model_type == "chroma_stem": - model_args.pop('cache_path', None) + elif model_type == 'chroma_stem': conditioners[str(cond)] = ChromaStemConditioner( output_dim=output_dim, duration=duration, device=device, **model_args ) + elif model_type == 'clap': + conditioners[str(cond)] = CLAPEmbeddingConditioner( + output_dim=output_dim, + device=device, + **model_args + ) else: - raise ValueError(f"unrecognized conditioning model: {model_type}") + raise ValueError(f"Unrecognized conditioning model: {model_type}") conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args) return conditioner def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: - """Instantiate a condition fuser object. 
- """ - fuser_cfg = getattr(cfg, "fuser") - fuser_methods = ["sum", "cross", "prepend", "input_interpolate"] + """Instantiate a condition fuser object.""" + fuser_cfg = getattr(cfg, 'fuser') + fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate'] fuse2cond = {k: fuser_cfg[k] for k in fuser_methods} kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods} fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs) @@ -164,13 +167,12 @@ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider: - """Instantiate a codebooks pattern provider object. - """ + """Instantiate a codebooks pattern provider object.""" pattern_providers = { 'parallel': ParallelPatternProvider, 'delay': DelayedPatternProvider, 'unroll': UnrolledPatternProvider, - 'valle': VALLEPattern, + 'coarse_first': CoarseFirstPattern, 'musiclm': MusicLMPattern, } name = cfg.modeling @@ -179,14 +181,20 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb return klass(n_q, **kwargs) -def get_debug_compression_model(device='cpu'): - """Instantiate a debug compression model to be used for unit tests. - """ - seanet_kwargs = { +def get_debug_compression_model(device='cpu', sample_rate: int = 32000): + """Instantiate a debug compression model to be used for unit tests.""" + assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model" + model_ratios = { + 16000: [10, 8, 8], # 25 Hz at 16kHz + 32000: [10, 8, 16] # 25 Hz at 32kHz + } + ratios: tp.List[int] = model_ratios[sample_rate] + frame_rate = 25 + seanet_kwargs: dict = { 'n_filters': 4, 'n_residual_layers': 1, 'dimension': 32, - 'ratios': [10, 8, 16] # 25 Hz at 32kHz + 'ratios': ratios, } encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs) decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs) @@ -195,13 +203,31 @@ def get_debug_compression_model(device='cpu'): quantizer(init_x, 1) # initialize kmeans etc. compression_model = EncodecModel( encoder, decoder, quantizer, - frame_rate=25, sample_rate=32000, channels=1).to(device) + frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device) return compression_model.eval() +def get_diffusion_model(cfg: omegaconf.DictConfig): + # TODO Find a way to infer the channels from dset + channels = cfg.channels + num_steps = cfg.schedule.num_steps + return DiffusionUnet( + chin=channels, num_steps=num_steps, **cfg.diffusion_unet) + + +def get_processor(cfg, sample_rate: int = 24000): + sample_processor = SampleProcessor() + if cfg.use: + kw = dict(cfg) + kw.pop('use') + kw.pop('name') + if cfg.name == "multi_band_processor": + sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw) + return sample_processor + + def get_debug_lm_model(device='cpu'): - """Instantiate a debug LM to be used for unit tests. 
- """ + """Instantiate a debug LM to be used for unit tests.""" pattern = DelayedPatternProvider(n_q=4) dim = 16 providers = { @@ -216,3 +242,17 @@ def get_debug_lm_model(device='cpu'): n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2, cross_attention=True, causal=True) return lm.to(device).eval() + + +def get_wrapped_compression_model( + compression_model: CompressionModel, + cfg: omegaconf.DictConfig) -> CompressionModel: + if hasattr(cfg, 'interleave_stereo_codebooks'): + if cfg.interleave_stereo_codebooks.use: + kwargs = dict_from_config(cfg.interleave_stereo_codebooks) + kwargs.pop('use') + compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs) + if hasattr(cfg, 'compression_model_n_q'): + if cfg.compression_model_n_q is not None: + compression_model.set_num_codebooks(cfg.compression_model_n_q) + return compression_model diff --git a/audiocraft/models/encodec.py b/audiocraft/models/encodec.py index 69621a695887b0b41614c51cae020f6fd0af221d..d4e77a941ef6b45ca54933afc6e430a75390013c 100644 --- a/audiocraft/models/encodec.py +++ b/audiocraft/models/encodec.py @@ -3,18 +3,32 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +"""Compression models or wrapper around existing models. +Also defines the main interface that a model must follow to be usable as an audio tokenizer. +""" from abc import ABC, abstractmethod +import logging +import math +from pathlib import Path import typing as tp from einops import rearrange +import numpy as np import torch from torch import nn +from transformers import EncodecModel as HFEncodecModel from .. import quantization as qt +logger = logging.getLogger() + + class CompressionModel(ABC, nn.Module): + """Base API for all compression model that aim at being used as audio tokenizers + with a language model. + """ @abstractmethod def forward(self, x: torch.Tensor) -> qt.QuantizedResult: @@ -22,12 +36,17 @@ class CompressionModel(ABC, nn.Module): @abstractmethod def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: - """See `EncodecModel.encode`""" + """See `EncodecModel.encode`.""" ... @abstractmethod def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): - """See `EncodecModel.decode`""" + """See `EncodecModel.decode`.""" + ... + + @abstractmethod + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" ... @property @@ -37,7 +56,7 @@ class CompressionModel(ABC, nn.Module): @property @abstractmethod - def frame_rate(self) -> int: + def frame_rate(self) -> float: ... @property @@ -62,10 +81,46 @@ class CompressionModel(ABC, nn.Module): @abstractmethod def set_num_codebooks(self, n: int): - """Set the active number of codebooks used by the quantizer. - """ + """Set the active number of codebooks used by the quantizer.""" ... + @staticmethod + def get_pretrained( + name: str, device: tp.Union[torch.device, str] = 'cpu' + ) -> 'CompressionModel': + """Instantiate a CompressionModel from a given pretrained model. + + Args: + name (Path or str): name of the pretrained model. See after. + device (torch.device or str): Device on which the model is loaded. 
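+
+        A hedged example (any name from the list below; `wav` is assumed to be a
+        [B, channels, T] float tensor at the model's sample rate):
+
+            model = CompressionModel.get_pretrained('facebook/encodec_32khz')
+            codes, scale = model.encode(wav)
+            reconstruction = model.decode(codes, scale)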
+ + Pretrained models: + - dac_44khz (https://github.com/descriptinc/descript-audio-codec) + - dac_24khz (same) + - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz) + - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz) + - your own model on HugginFace. Export instructions to come... + """ + + from . import builders, loaders + model: CompressionModel + if name in ['dac_44khz', 'dac_24khz']: + model_type = name.split('_')[1] + logger.info("Getting pretrained compression model from DAC %s", model_type) + model = DAC(model_type) + elif name in ['debug_compression_model']: + logger.info("Getting pretrained compression model for debug") + model = builders.get_debug_compression_model() + elif Path(name).exists(): + # We assume here if the paths exist that it is in fact an AC checkpoint + # that was exported using `audiocraft.utils.export` functions. + model = loaders.load_compression_model(name, device=device) + else: + logger.info("Getting pretrained compression model from HF %s", name) + hf_model = HFEncodecModel.from_pretrained(name) + model = HFEncodecCompressionModel(hf_model).to(device) + return model.to(device).eval() + class EncodecModel(CompressionModel): """Encodec model operating on the raw waveform. @@ -80,9 +135,9 @@ class EncodecModel(CompressionModel): causal (bool): Whether to use a causal version of the model. renormalize (bool): Whether to renormalize the audio before running the model. """ - # we need assignement to override the property in the abstract class, + # we need assignment to override the property in the abstract class, # I couldn't find a better way... - frame_rate: int = 0 + frame_rate: float = 0 sample_rate: int = 0 channels: int = 0 @@ -111,25 +166,21 @@ class EncodecModel(CompressionModel): @property def total_codebooks(self): - """Total number of quantizer codebooks available. - """ + """Total number of quantizer codebooks available.""" return self.quantizer.total_codebooks @property def num_codebooks(self): - """Active number of codebooks used by the quantizer. - """ + """Active number of codebooks used by the quantizer.""" return self.quantizer.num_codebooks def set_num_codebooks(self, n: int): - """Set the active number of codebooks used by the quantizer. - """ + """Set the active number of codebooks used by the quantizer.""" self.quantizer.set_num_codebooks(n) @property def cardinality(self): - """Cardinality of each codebook. - """ + """Cardinality of each codebook.""" return self.quantizer.bins def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: @@ -176,7 +227,7 @@ class EncodecModel(CompressionModel): x (torch.Tensor): Float tensor of shape [B, C, T] Returns: - codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of: + codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of: codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. scale a float tensor containing the scale for audio renormalizealization. """ @@ -192,41 +243,174 @@ class EncodecModel(CompressionModel): Args: codes (torch.Tensor): Int tensor of shape [B, K, T] - scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value. + scale (torch.Tensor, optional): Float tensor containing the scale value. Returns: out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. 
""" - emb = self.quantizer.decode(codes) + emb = self.decode_latent(codes) out = self.decoder(emb) out = self.postprocess(out, scale) # out contains extra padding added by the encoder and decoder return out + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.quantizer.decode(codes) + + +class DAC(CompressionModel): + def __init__(self, model_type: str = "44khz"): + super().__init__() + try: + import dac.utils + except ImportError: + raise RuntimeError("Could not import dac, make sure it is installed, " + "please run `pip install descript-audio-codec`") + self.model = dac.utils.load_model(model_type=model_type) + self.n_quantizers = self.total_codebooks + self.model.eval() + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + codes = self.model.encode(x, self.n_quantizers)[1] + return codes[:, :self.n_quantizers], None + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + assert scale is None + z_q = self.decode_latent(codes) + return self.model.decode(z_q) + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.model.quantizer.from_codes(codes)[0] + + @property + def channels(self) -> int: + return 1 + + @property + def frame_rate(self) -> float: + return self.model.sample_rate / self.model.hop_length + + @property + def sample_rate(self) -> int: + return self.model.sample_rate + + @property + def cardinality(self) -> int: + return self.model.codebook_size + + @property + def num_codebooks(self) -> int: + return self.n_quantizers -class FlattenedCompressionModel(CompressionModel): - """Wraps a CompressionModel and flatten its codebooks, e.g. - instead of returning [B, K, T], return [B, S, T * (K // S)] with - S the number of codebooks per step, and `K // S` the number of 'virtual steps' - for each real time step. + @property + def total_codebooks(self) -> int: + return self.model.n_codebooks + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n + + +class HFEncodecCompressionModel(CompressionModel): + """Wrapper around HuggingFace Encodec. + """ + def __init__(self, model: HFEncodecModel): + super().__init__() + self.model = model + bws = self.model.config.target_bandwidths + num_codebooks = [ + bw * 1000 / (self.frame_rate * math.log2(self.cardinality)) + for bw in bws + ] + deltas = [nc - int(nc) for nc in num_codebooks] + # Checking we didn't do some bad maths and we indeed have integers! + assert all(deltas) <= 1e-3, deltas + self.possible_num_codebooks = [int(nc) for nc in num_codebooks] + self.set_num_codebooks(max(self.possible_num_codebooks)) + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + # We don't support training with this. 
+ raise NotImplementedError("Forward and training with HF EncodecModel not supported.") + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks) + bandwidth = self.model.config.target_bandwidths[bandwidth_index] + res = self.model.encode(x, None, bandwidth) + assert len(res[0]) == 1 + assert len(res[1]) == 1 + return res[0][0], res[1][0] + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + if scale is None: + scales = [None] # type: ignore + else: + scales = scale # type: ignore + res = self.model.decode(codes[None], scales) + return res[0] + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.model.quantizer.decode(codes.transpose(0, 1)) + + @property + def channels(self) -> int: + return self.model.config.audio_channels + + @property + def frame_rate(self) -> float: + hop_length = int(np.prod(self.model.config.upsampling_ratios)) + return self.sample_rate / hop_length + + @property + def sample_rate(self) -> int: + return self.model.config.sampling_rate + + @property + def cardinality(self) -> int: + return self.model.config.codebook_size + + @property + def num_codebooks(self) -> int: + return self._num_codebooks + + @property + def total_codebooks(self) -> int: + return max(self.possible_num_codebooks) + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + if n not in self.possible_num_codebooks: + raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}") + self._num_codebooks = n + + +class InterleaveStereoCompressionModel(CompressionModel): + """Wraps a CompressionModel to support stereo inputs. The wrapped model + will be applied independently to the left and right channels, and both codebooks + will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per + channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on + `per_timestep`. Args: - model (CompressionModel): compression model to wrap. - codebooks_per_step (int): number of codebooks to keep per step, - this must divide the number of codebooks provided by the wrapped model. - extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1, - if each codebook has a cardinality N, then the first codebook will - use the range [0, N - 1], and the second [N, 2 N - 1] etc. - On decoding, this can lead to potentially invalid sequences. - Any invalid entry will be silently remapped to the proper range - with a modulo. + model (CompressionModel): Compression model to wrap. + per_timestep (bool): Whether to interleave on the timestep dimension + or on the codebooks dimension. """ - def __init__(self, model: CompressionModel, codebooks_per_step: int = 1, - extend_cardinality: bool = True): + def __init__(self, model: CompressionModel, per_timestep: bool = False): super().__init__() self.model = model - self.codebooks_per_step = codebooks_per_step - self.extend_cardinality = extend_cardinality + self.per_timestep = per_timestep + assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio" @property def total_codebooks(self): @@ -236,30 +420,27 @@ class FlattenedCompressionModel(CompressionModel): def num_codebooks(self): """Active number of codebooks used by the quantizer. 
- ..Warning:: this reports the number of codebooks after the flattening + ..Warning:: this reports the number of codebooks after the interleaving of the codebooks! """ - assert self.model.num_codebooks % self.codebooks_per_step == 0 - return self.codebooks_per_step + return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2 def set_num_codebooks(self, n: int): """Set the active number of codebooks used by the quantizer. - ..Warning:: this sets the number of codebooks **before** the flattening - of the codebooks. + ..Warning:: this sets the number of codebooks before the interleaving! """ - assert n % self.codebooks_per_step == 0 self.model.set_num_codebooks(n) @property - def num_virtual_steps(self) -> int: + def num_virtual_steps(self) -> float: """Return the number of virtual steps, e.g. one real step will be split into that many steps. """ - return self.model.num_codebooks // self.codebooks_per_step + return 2 if self.per_timestep else 1 @property - def frame_rate(self) -> int: + def frame_rate(self) -> float: return self.model.frame_rate * self.num_virtual_steps @property @@ -268,35 +449,58 @@ class FlattenedCompressionModel(CompressionModel): @property def channels(self) -> int: - return self.model.channels + return 2 @property def cardinality(self): """Cardinality of each codebook. """ - if self.extend_cardinality: - return self.model.cardinality * self.num_virtual_steps - else: - return self.model.cardinality + return self.model.cardinality def forward(self, x: torch.Tensor) -> qt.QuantizedResult: raise NotImplementedError("Not supported, use encode and decode.") def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: - indices, scales = self.model.encode(x) - B, K, T = indices.shape - indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step) - if self.extend_cardinality: - for virtual_step in range(1, self.num_virtual_steps): - indices[..., virtual_step] += self.model.cardinality * virtual_step - indices = rearrange(indices, 'b k t v -> b k (t v)') + B, C, T = x.shape + assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}" + + indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1)) + indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1)) + indices = torch.stack([indices_c0, indices_c1], dim=0) + scales: tp.Optional[torch.Tensor] = None + if scales_c0 is not None and scales_c1 is not None: + scales = torch.stack([scales_c0, scales_c1], dim=1) + + if self.per_timestep: + indices = rearrange(indices, 'c b k t -> b k (t c)', c=2) + else: + indices = rearrange(indices, 'c b k t -> b (k c) t', c=2) + return (indices, scales) + def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: + if self.per_timestep: + codes = rearrange(codes, 'b k (t c) -> c b k t', c=2) + else: + codes = rearrange(codes, 'b (k c) t -> c b k t', c=2) + return codes[0], codes[1] + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): B, K, T = codes.shape - assert T % self.num_virtual_steps == 0 - codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps) - # We silently ignore potential errors from the LM when - # using extend_cardinality. 
- codes = codes % self.model.cardinality - return self.model.decode(codes, scale) + assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match" + assert K == self.num_codebooks, "Provided codes' number of codebooks does not match" + + scale_c0, scale_c1 = None, None + if scale is not None: + assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}" + scale_c0 = scale[0, ...] + scale_c1 = scale[1, ...] + + codes_c0, codes_c1 = self.get_left_right_codes(codes) + audio_c0 = self.model.decode(codes_c0, scale_c0) + audio_c1 = self.model.decode(codes_c1, scale_c1) + return torch.cat([audio_c0, audio_c1], dim=1) + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + raise NotImplementedError("Not supported by interleaved stereo wrapped models.") diff --git a/audiocraft/models/lm.py b/audiocraft/models/lm.py index ea59c5a3cb1e9f5c88bdee8b26c2540a45a509d5..c2be2559251f5a300f5b582a234f6752ebe6094d 100644 --- a/audiocraft/models/lm.py +++ b/audiocraft/models/lm.py @@ -41,7 +41,7 @@ def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None method (str): Method name for init function. Valid options are: 'gaussian', 'uniform'. input_dim (int): Input dimension of the initialized module. - init_depth (Optional[int]): Optional init depth value used to rescale + init_depth (int, optional): Optional init depth value used to rescale the standard deviation if defined. """ # Compute std @@ -70,7 +70,7 @@ def init_layer(m: nn.Module, Args: m (nn.Module): Module to initialize. method (str): Method name for the init function. - init_depth (Optional[int]): Optional init depth value used to rescale + init_depth (int, optional): Optional init depth value used to rescale the standard deviation if defined. zero_bias_init (bool): Whether to initialize the bias to 0 or not. """ @@ -130,10 +130,10 @@ class LMModel(StreamingModule): hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder. norm (str): Normalization method. norm_first (bool): Use pre-norm instead of post-norm. - emb_lr (Optional[float]): Embedding-specific learning rate. + emb_lr (float, optional): Embedding-specific learning rate. bias_proj (bool): Use bias for output projections. - weight_init (Optional[str]): Method for weight initialization. - depthwise_init (Optional[str]): Method for depthwise weight initialization. + weight_init (str, optional): Method for weight initialization. + depthwise_init (str, optional): Method for depthwise weight initialization. zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros. cfg_dropout (float): Classifier-free guidance dropout. cfg_coef (float): Classifier-free guidance coefficient. @@ -179,11 +179,11 @@ class LMModel(StreamingModule): """Initialization of the transformer module weights. Args: - weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options. - depthwise_init (Optional[str]): Depwthwise initialization strategy. The following options are valid: + weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options. + depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid: 'current' where the depth corresponds to the current layer index or 'global' where the total number of layer is used as depth. If not set, no depthwise initialization strategy is used. 
- zero_bias_init (bool): Whether to initalize bias to zero or not. + zero_bias_init (bool): Whether to initialize bias to zero or not. """ assert depthwise_init is None or depthwise_init in ['current', 'global'] assert depthwise_init is None or weight_init is not None, \ @@ -225,17 +225,17 @@ class LMModel(StreamingModule): S the sequence steps, return the logits with shape [B, card, K, S]. Args: - indices (torch.Tensor): indices of the codes to model. - conditions (list[ConditioningAttributes]): conditionings to use when modeling + indices (torch.Tensor): Indices of the codes to model. + conditions (list of ConditioningAttributes): Conditions to use when modeling the given codes. Note that when evaluating multiple time with the same conditioning you should pre-compute those and pass them as `condition_tensors`. - condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning + condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning tensors, see `conditions`. Returns: torch.Tensor: Logits. """ B, K, S = sequence.shape - assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks' + assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks" input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)]) if condition_tensors is None: assert not self._is_streaming, "Conditions tensors should be precomputed when streaming." @@ -271,10 +271,10 @@ class LMModel(StreamingModule): Args: codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size, K the number of codebooks and T the number of timesteps. - conditions (list[ConditioningAttributes]): conditionings to use when modeling + conditions (list of ConditioningAttributes): conditionings to use when modeling the given codes. Note that when evaluating multiple time with the same conditioning you should pre-compute those and pass them as `condition_tensors`. - condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning + condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning tensors, see `conditions`. Returns: LMOutput: Language model outputs @@ -314,7 +314,8 @@ class LMModel(StreamingModule): temp: float = 1.0, top_k: int = 0, top_p: float = 0.0, - cfg_coef: tp.Optional[float] = None) -> torch.Tensor: + cfg_coef: tp.Optional[float] = None, + two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor: """Sample next token from the model given a sequence and a set of conditions. The model supports multiple sampling strategies (greedy sampling, softmax, top-k, top-p...). @@ -322,21 +323,22 @@ class LMModel(StreamingModule): sequence (torch.Tensor): Current sequence of shape [B, K, S] with K corresponding to the number of codebooks and S the number of sequence steps. S = 1 in streaming mode, except for the first step that contains a bigger prompt. - condition_tensors (Dict[str, ConditionType): Set of conditions. If CFG is used, + condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used, should be twice the batch size, being the concatenation of the conditions + null conditions. use_sampling (bool): Whether to use a sampling strategy or not. temp (float): Sampling temperature. top_k (int): K for "top-k" sampling. top_p (float): P for "top-p" sampling. - cfg_coef (float): classifier free guidance coefficient + cfg_coef (float, optional): classifier free guidance coefficient Returns: next_token (torch.Tensor): Next token tensor of shape [B, K, 1]. 
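+
+        When CFG is in use, the conditional and unconditional logits are combined as
+        `uncond_logits + cfg_coef * (cond_logits - uncond_logits)` before sampling.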
""" B = sequence.shape[0] cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef model = self if self._fsdp is None else self._fsdp - if self.two_step_cfg and cfg_conditions != {}: - assert isinstance(cfg_conditions, tuple) + two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg + if two_step_cfg and cfg_conditions != {}: + assert isinstance(cfg_conditions, tuple), type(cfg_conditions) condition_tensors, null_condition_tensors = cfg_conditions cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors) state = self.get_streaming_state() @@ -388,7 +390,7 @@ class LMModel(StreamingModule): top_k: int = 250, top_p: float = 0.0, cfg_coef: tp.Optional[float] = None, - two_step_cfg: bool = False, + two_step_cfg: tp.Optional[bool] = None, remove_prompts: bool = False, check: bool = False, callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor: @@ -396,15 +398,19 @@ class LMModel(StreamingModule): be perform in a greedy fashion or using sampling with top K and top P strategies. Args: - prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T]. - conditions_tensors (Dict[str, torch.Tensor]): Set of conditions or None. - num_samples (int or None): Number of samples to generate when no prompt and no conditions are given. + prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T]. + conditions_tensors (list of ConditioningAttributes, optional): List of conditions. + num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given. max_gen_len (int): Maximum generation length. use_sampling (bool): Whether to use a sampling strategy or not. temp (float): Sampling temperature. top_k (int): K for "top-k" sampling. top_p (float): P for "top-p" sampling. + cfg_coeff (float, optional): Classifier-free guidance coefficient. + two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation. remove_prompts (bool): Whether to remove prompts from generation or not. + check (bool): Whether to apply further checks on generated sequence. + callback (Callback, optional): Callback function to report generation progress. Returns: torch.Tensor: Generated tokens. """ @@ -412,7 +418,7 @@ class LMModel(StreamingModule): first_param = next(iter(self.parameters())) device = first_param.device - # Checking all input shapes are consistents. + # Checking all input shapes are consistent. possible_num_samples = [] if num_samples is not None: possible_num_samples.append(num_samples) @@ -422,7 +428,7 @@ class LMModel(StreamingModule): possible_num_samples.append(len(conditions)) else: possible_num_samples.append(1) - assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsitent inputs shapes" + assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes" num_samples = possible_num_samples[0] # below we create set of conditions: one conditional and one unconditional @@ -432,7 +438,7 @@ class LMModel(StreamingModule): # 1. it is about x2 faster than doing 2 forward passes # 2. avoid the streaming API treating the 2 passes as part of different time steps # We also support doing two different passes, in particular to ensure that - # the padding structure is exactly the same between train anf test. + # the padding structure is exactly the same between train and test. # With a batch size of 1, this can be slower though. 
cfg_conditions: CFGConditions two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg @@ -457,8 +463,8 @@ class LMModel(StreamingModule): B, K, T = prompt.shape start_offset = T print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}") - assert start_offset <= max_gen_len - + assert start_offset < max_gen_len + pattern = self.pattern_provider.get_pattern(max_gen_len) # this token is used as default value for codes that are not generated yet unknown_token = -1 @@ -490,7 +496,7 @@ class LMModel(StreamingModule): # sample next token from the model, next token shape is [B, K, 1] next_token = self._sample_next_token( curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, - cfg_coef=cfg_coef) + cfg_coef=cfg_coef, two_step_cfg=two_step_cfg) # ensure the tokens that should be masked are properly set to special_token_id # as the model never output special_token_id valid_mask = mask[..., offset:offset+1].expand(B, -1, -1) diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index b6cbf13db1d9e79f21cf649aa25ce8c3e36adfec..c9b4bf69ba93c96717e2c72434d4de06741a4cad 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -24,10 +24,16 @@ from huggingface_hub import hf_hub_download import typing as tp import os -from omegaconf import OmegaConf +from omegaconf import OmegaConf, DictConfig import torch +import audiocraft from . import builders +from .encodec import CompressionModel + + +def get_audiocraft_cache_dir() -> tp.Optional[str]: + return os.environ.get('AUDIOCRAFT_CACHE_DIR', None) HF_MODEL_CHECKPOINTS_MAP = { @@ -50,6 +56,8 @@ def _get_state_dict( device='cpu', cache_dir: tp.Optional[str] = None, ): + if cache_dir is None: + cache_dir = get_audiocraft_cache_dir() # Return the state dict either from a file or url file_or_url_or_id = str(file_or_url_or_id) assert isinstance(file_or_url_or_id, str) @@ -72,21 +80,120 @@ def _get_state_dict( return torch.load(file, map_location=device) else: - raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.") + assert filename is not None, "filename needs to be defined if using HF checkpoints" + + file = hf_hub_download( + repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, + library_name="audiocraft", library_version=audiocraft.__version__) + return torch.load(file, map_location=device) + +def create_melody_config(model_id: str, device: str) -> DictConfig: + """Create a fallback configuration for melody models. + + Args: + model_id: The model identifier + device: The device to use + + Returns: + A compatible OmegaConf DictConfig + """ + base_cfg = { + "device": str(device), + "channels": 2 if "stereo" in model_id else 1, + "sample_rate": 32000, + "audio_channels": 2 if "stereo" in model_id else 1, + "frame_rate": 50, + "codec_name": "encodec", + "codec": { + "dim": 128, + "hidden_dim": 1024, + "stride": 320, + "n_q": 4, + "codebook_size": 2048, + "normalize": True, + } + } + return OmegaConf.create(base_cfg) + +def create_default_config(model_id: str, device: str) -> DictConfig: + """Create a fallback configuration for standard models. 
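
With `get_audiocraft_cache_dir` above, checkpoint downloads can be redirected without touching call sites: when no explicit `cache_dir` is passed, `_get_state_dict` falls back to the `AUDIOCRAFT_CACHE_DIR` environment variable. A hedged sketch; the target directory is a placeholder and the repo id is assumed for illustration (any repo or local path containing the expected checkpoint files should go through the same logic):

```python
import os

# Hypothetical location; any writable directory works.
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/data/audiocraft_cache"

from audiocraft.models.loaders import load_compression_model, load_lm_model

# Assumed HF repo id for illustration; it must provide compression_state_dict.bin
# and state_dict.bin for these two loaders.
compression_model = load_compression_model("facebook/musicgen-melody", device="cpu")
lm = load_lm_model("facebook/musicgen-melody", device="cpu")
```
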
+ + Args: + model_id: The model identifier + device: The device to use + + Returns: + A compatible OmegaConf DictConfig + """ + base_cfg = { + "device": str(device), + "channels": 2 if "stereo" in model_id else 1, + "sample_rate": 32000, + "audio_channels": 2 if "stereo" in model_id else 1, + "frame_rate": 50, + "codec_name": "encodec", + "codec": { + "dim": 128, + "hidden_dim": 1024, + "stride": 320, + "n_q": 4, + "codebook_size": 1024, + "normalize": True, + } + } + return OmegaConf.create(base_cfg) + + +def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None): + return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir) def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): - pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir) - cfg = OmegaConf.create(pkg['xp.cfg']) + pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) + if 'pretrained' in pkg: + return CompressionModel.get_pretrained(pkg['pretrained'], device=device) + + # Handle newer model formats that might not have xp.cfg + if 'xp.cfg' not in pkg: + if file_or_url_or_id in ['melody-large', 'stereo-melody', 'stereo-medium', + 'stereo-small', 'stereo-large', 'stereo-melody-large']: + print(f"Using fallback configuration for {file_or_url_or_id}") + # Create a default configuration based on the model type + # This is where you'd need to add model-specific configurations + if 'melody' in file_or_url_or_id: + cfg = create_melody_config(file_or_url_or_id, device) + else: + cfg = create_default_config(file_or_url_or_id, device) + else: + raise KeyError(f"Missing configuration for model {file_or_url_or_id}") + else: + cfg = OmegaConf.create(pkg['xp.cfg']) + cfg.device = str(device) model = builders.get_compression_model(cfg) model.load_state_dict(pkg['best_state']) model.eval() return model +def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None): + return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir) + + +def _delete_param(cfg: DictConfig, full_name: str): + parts = full_name.split('.') + for part in parts[:-1]: + if part in cfg: + cfg = cfg[part] + else: + return + OmegaConf.set_struct(cfg, False) + if parts[-1] in cfg: + del cfg[parts[-1]] + OmegaConf.set_struct(cfg, True) + def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): - pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir) + pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) cfg = OmegaConf.create(pkg['xp.cfg']) cfg.device = str(device) if cfg.device == 'cpu': @@ -95,8 +202,42 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_di cfg.dtype = 'float32' else: cfg.dtype = 'float16' + _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path') + _delete_param(cfg, 'conditioners.args.merge_text_conditions_p') + _delete_param(cfg, 'conditioners.args.drop_desc_p') model = builders.get_lm_model(cfg) model.load_state_dict(pkg['best_state']) model.eval() model.cfg = cfg return model + + +def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], + filename: tp.Optional[str] = None, + cache_dir: tp.Optional[str] = None): + return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir) + + +def 
load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], + device='cpu', + filename: tp.Optional[str] = None, + cache_dir: tp.Optional[str] = None): + pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir) + models = [] + processors = [] + cfgs = [] + sample_rate = pkg['sample_rate'] + for i in range(pkg['n_bands']): + cfg = pkg[i]['cfg'] + model = builders.get_diffusion_model(cfg) + model_dict = pkg[i]['model_state'] + model.load_state_dict(model_dict) + model.to(device) + processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate) + processor_dict = pkg[i]['processor_state'] + processor.load_state_dict(processor_dict) + processor.to(device) + models.append(model) + processors.append(processor) + cfgs.append(cfg) + return models, processors, cfgs \ No newline at end of file diff --git a/audiocraft/models/musicgen.py b/audiocraft/models/musicgen.py index bb53b71784856b0bc3dfea798e855712ba01ba9b..29bb57e27b96b3e1f74847fbf90c0ab09c3601c6 100644 --- a/audiocraft/models/musicgen.py +++ b/audiocraft/models/musicgen.py @@ -11,18 +11,19 @@ and provide easy access to the generation API. import os import typing as tp +import warnings +import omegaconf import torch from .encodec import CompressionModel from .lm import LMModel -from .builders import get_debug_compression_model, get_debug_lm_model +from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP from ..data.audio_utils import convert_audio from ..modules.conditioners import ConditioningAttributes, WavCondition from ..utils.autocast import TorchAutocast - MelodyList = tp.List[tp.Optional[torch.Tensor]] MelodyType = tp.Union[torch.Tensor, MelodyList] @@ -35,11 +36,32 @@ class MusicGen: compression_model (CompressionModel): Compression model used to map audio to invertible discrete representations. lm (LMModel): Language model over discrete representations. + max_duration (float, optional): maximum duration the model can produce, + otherwise, inferred from the training params. """ - def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, max_duration: float = 30): + def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, max_duration: tp.Optional[float] = 30): self.name = name self.compression_model = compression_model self.lm = lm + self.cfg: tp.Optional[omegaconf.DictConfig] = None + # Just to be safe, let's put everything in eval mode. 
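
`load_diffusion_models` above returns one `(model, processor, cfg)` triple per frequency band of a MultiBandDiffusion checkpoint. A sketch of consuming it; both the repo id and the checkpoint filename are assumptions for illustration, not values fixed by this diff:

```python
from audiocraft.models.loaders import load_diffusion_models

# Repo id and filename are assumed; substitute the MultiBandDiffusion release you use.
models, processors, cfgs = load_diffusion_models(
    "facebook/multiband-diffusion", filename="mbd_musicgen_32khz.th", device="cpu")

for band, (model, processor) in enumerate(zip(models, processors)):
    # One diffusion UNet and one processor per band, already moved to the target device.
    print(f"band {band}: {type(model).__name__} / {type(processor).__name__}")
```
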
+ self.compression_model.eval() + self.lm.eval() + + if hasattr(lm, 'cfg'): + cfg = lm.cfg + assert isinstance(cfg, omegaconf.DictConfig) + self.cfg = cfg + + if self.cfg is not None: + self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg) + + if max_duration is None: + if self.cfg is not None: + max_duration = lm.cfg.dataset.segment_duration # type: ignore + else: + raise ValueError("You must provide max_duration when building directly MusicGen") + assert max_duration is not None self.max_duration = max_duration self.duration = 15.0 # default duration self.device = next(iter(lm.parameters())).device @@ -53,7 +75,12 @@ class MusicGen: enabled=True, device_type=self.device.type, dtype=torch.float16) @property - def frame_rate(self) -> int: + def version(self) -> str: + from audiocraft import __version__ as audiocraft_version + return audiocraft_version + + @property + def frame_rate(self) -> float: """Roughly the number of AR steps per seconds.""" return self.compression_model.frame_rate @@ -100,12 +127,15 @@ class MusicGen: f"{name} is not a valid checkpoint name. " f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}" ) + else: + name = HF_MODEL_CHECKPOINTS_MAP[name] cache_dir = os.environ.get('MUSICGEN_ROOT', None) compression_model = load_compression_model(name, device=device, cache_dir=cache_dir) lm = load_lm_model(name, device=device, cache_dir=cache_dir) - if name == 'melody': + if name.__contains__('melody') or 'self_wav' in lm.condition_provider.conditioners: lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True + lm.condition_provider.conditioners['self_wav']._use_masking = False return MusicGen(name, compression_model, lm) @@ -125,6 +155,9 @@ class MusicGen: two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance, instead of batching together the two. This has some impact on how things are padded but seems to have little impact in practice. + extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much + should we extend the audio each time. Larger values will mean less context is + preserved, and shorter value will require extra computations. rep_penalty (float, optional): If set, use repetition penalty during generation. Not Implemented. """ assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration." @@ -137,47 +170,61 @@ class MusicGen: 'top_k': top_k, 'top_p': top_p, 'cfg_coef': cfg_coef, - 'two_step_cfg': two_step_cfg, + 'two_step_cfg': two_step_cfg, } def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None): """Override the default progress callback.""" self._progress_callback = progress_callback - def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor: + def generate_unconditional(self, num_samples: int, progress: bool = False, + return_tokens: bool = False) -> tp.Union[torch.Tensor, + tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples in an unconditional manner. Args: num_samples (int): Number of samples to be generated. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False. 
""" descriptions: tp.List[tp.Optional[str]] = [None] * num_samples attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) - return self._generate_tokens(attributes, prompt_tokens, progress) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) - def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor: + def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples conditioned on text. Args: - descriptions (tp.List[str]): A list of strings used as text conditioning. + descriptions (list of str): A list of strings used as text conditioning. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False. """ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) assert prompt_tokens is None - return self._generate_tokens(attributes, prompt_tokens, progress) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType, - melody_sample_rate: int, progress: bool = False) -> torch.Tensor: + melody_sample_rate: int, progress: bool = False, + return_tokens: bool = False) -> tp.Union[torch.Tensor, + tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples conditioned on text and melody. Args: - descriptions (tp.List[str]): A list of strings used as text conditioning. + descriptions (list of str): A list of strings used as text conditioning. melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as melody conditioning. Should have shape [B, C, T] with B matching the description length, C=1 or 2. It can be [C, T] if there is a single description. It can also be a list of [C, T] tensors. melody_sample_rate: (int): Sample rate of the melody waveforms. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False. """ if isinstance(melody_wavs, torch.Tensor): if melody_wavs.dim() == 2: @@ -197,10 +244,14 @@ class MusicGen: attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None, melody_wavs=melody_wavs) assert prompt_tokens is None - return self._generate_tokens(attributes, prompt_tokens, progress) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) def generate_with_all(self, descriptions: tp.List[str], melody_wavs: MelodyType, - sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None, return_tokens: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples conditioned on text and melody and audio prompts. Args: descriptions (tp.List[str]): A list of strings used as text conditioning. 
@@ -249,19 +300,24 @@ class MusicGen: assert prompt_tokens is not None else: assert prompt_tokens is None - return self._generate_tokens(attributes, prompt_tokens, progress) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, - progress: bool = False) -> torch.Tensor: + progress: bool = False, return_tokens: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples conditioned on audio prompts. Args: prompt (torch.Tensor): A batch of waveforms used for continuation. Prompt should be [B, C, T], or [C, T] if only one sample is generated. prompt_sample_rate (int): Sampling rate of the given audio waveforms. - descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None. + descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False. """ if prompt.dim() == 2: prompt = prompt[None] @@ -272,7 +328,10 @@ class MusicGen: descriptions = [None] * len(prompt) attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) assert prompt_tokens is not None - return self._generate_tokens(attributes, prompt_tokens, progress) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) @torch.no_grad() def _prepare_tokens_and_attributes( @@ -284,9 +343,9 @@ class MusicGen: """Prepare model inputs. Args: - descriptions (tp.List[str]): A list of strings used as text conditioning. + descriptions (list of str): A list of strings used as text conditioning. prompt (torch.Tensor): A batch of waveforms used for continuation. - melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms + melody_wavs (torch.Tensor, optional): A batch of waveforms used as melody conditioning. Defaults to None. """ attributes = [ @@ -296,11 +355,12 @@ class MusicGen: if melody_wavs is None: for attr in attributes: attr.wav['self_wav'] = WavCondition( - torch.zeros((1, 1), device=self.device), + torch.zeros((1, 1, 1), device=self.device), torch.tensor([0], device=self.device), - path='null_wav') # type: ignore + sample_rate=[self.sample_rate], + path=[None]) # type: ignore else: - if self.name != "melody": + if 'self_wav' not in self.lm.condition_provider.conditioners: raise RuntimeError("This model doesn't support melody conditioning. 
" "Use the `melody` model.") assert len(melody_wavs) == len(descriptions), \ @@ -309,13 +369,17 @@ class MusicGen: for attr, melody in zip(attributes, melody_wavs): if melody is None: attr.wav['self_wav'] = WavCondition( - torch.zeros((1, 1), device=self.device), + torch.zeros((1, 1, 1), device=self.device), torch.tensor([0], device=self.device), - path='null_wav') # type: ignore + sample_rate=[self.sample_rate], + path=[None]) # type: ignore else: attr.wav['self_wav'] = WavCondition( - melody.to(device=self.device), - torch.tensor([melody.shape[-1]], device=self.device)) + melody[None].to(device=self.device), + torch.tensor([melody.shape[-1]], device=self.device), + sample_rate=[self.sample_rate], + path=[None], + ) if prompt is not None: if descriptions is not None: @@ -396,8 +460,10 @@ class MusicGen: positions = torch.arange(initial_position, initial_position + wav_target_length, device=self.device) attr.wav['self_wav'] = WavCondition( - ref_wav[0][:, positions % wav_length], - torch.full_like(ref_wav[1], wav_target_length)) + ref_wav[0][..., positions % wav_length], + torch.full_like(ref_wav[1], wav_target_length), + [self.sample_rate] * ref_wav[0].size(0), + [None], [0.]) with self.autocast: gen_tokens = self.lm.generate( prompt_tokens, attributes, @@ -411,13 +477,21 @@ class MusicGen: current_gen_offset += stride_tokens gen_tokens = torch.cat(all_tokens, dim=-1) + return gen_tokens # generate audio - assert gen_tokens.dim() == 3 - with torch.no_grad(): - gen_audio = self.compression_model.decode(gen_tokens, None) - return gen_audio + def generate_audio(self, gen_tokens: torch.Tensor): + try: + """Generate Audio from tokens""" + assert gen_tokens.dim() == 3 + with torch.no_grad(): + gen_audio = self.compression_model.decode(gen_tokens, None) + return gen_audio + except Exception as e: + print(f"Error generating audio: {e}") + return None + #def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], # prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor: # """Generate discrete audio tokens given audio prompt and/or conditions. diff --git a/audiocraft/models/unet.py b/audiocraft/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..db4a6df8e309c21fede37abdbe3c862932027641 --- /dev/null +++ b/audiocraft/models/unet.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Pytorch Unet Module used for diffusion. 
+""" + +from dataclasses import dataclass +import typing as tp + +import torch +from torch import nn +from torch.nn import functional as F +from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding + + +@dataclass +class Output: + sample: torch.Tensor + + +def get_model(cfg, channels: int, side: int, num_steps: int): + if cfg.model == 'unet': + return DiffusionUnet( + chin=channels, num_steps=num_steps, **cfg.diffusion_unet) + else: + raise RuntimeError('Not Implemented') + + +class ResBlock(nn.Module): + def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4, + dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + dropout: float = 0.): + super().__init__() + stride = 1 + padding = dilation * (kernel - stride) // 2 + Conv = nn.Conv1d + Drop = nn.Dropout1d + self.norm1 = nn.GroupNorm(norm_groups, channels) + self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) + self.activation1 = activation() + self.dropout1 = Drop(dropout) + + self.norm2 = nn.GroupNorm(norm_groups, channels) + self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) + self.activation2 = activation() + self.dropout2 = Drop(dropout) + + def forward(self, x): + h = self.dropout1(self.conv1(self.activation1(self.norm1(x)))) + h = self.dropout2(self.conv2(self.activation2(self.norm2(h)))) + return x + h + + +class DecoderLayer(nn.Module): + def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, + norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + dropout: float = 0.): + super().__init__() + padding = (kernel - stride) // 2 + self.res_blocks = nn.Sequential( + *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) + for idx in range(res_blocks)]) + self.norm = nn.GroupNorm(norm_groups, chin) + ConvTr = nn.ConvTranspose1d + self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False) + self.activation = activation() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.res_blocks(x) + x = self.norm(x) + x = self.activation(x) + x = self.convtr(x) + return x + + +class EncoderLayer(nn.Module): + def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, + norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + dropout: float = 0.): + super().__init__() + padding = (kernel - stride) // 2 + Conv = nn.Conv1d + self.conv = Conv(chin, chout, kernel, stride, padding, bias=False) + self.norm = nn.GroupNorm(norm_groups, chout) + self.activation = activation() + self.res_blocks = nn.Sequential( + *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) + for idx in range(res_blocks)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, T = x.shape + stride, = self.conv.stride + pad = (stride - (T % stride)) % stride + x = F.pad(x, (0, pad)) + + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + x = self.res_blocks(x) + return x + + +class BLSTM(nn.Module): + """BiLSTM with same hidden units as input dim. 
+ """ + def __init__(self, dim, layers=2): + super().__init__() + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + + def forward(self, x): + x = x.permute(2, 0, 1) + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + return x + + +class DiffusionUnet(nn.Module): + def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2., + max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False, + bilstm: bool = False, transformer: bool = False, + codec_dim: tp.Optional[int] = None, **kwargs): + super().__init__() + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.embeddings: tp.Optional[nn.ModuleList] = None + self.embedding = nn.Embedding(num_steps, hidden) + if emb_all_layers: + self.embeddings = nn.ModuleList() + self.condition_embedding: tp.Optional[nn.Module] = None + for d in range(depth): + encoder = EncoderLayer(chin, hidden, **kwargs) + decoder = DecoderLayer(hidden, chin, **kwargs) + self.encoders.append(encoder) + self.decoders.insert(0, decoder) + if emb_all_layers and d > 0: + assert self.embeddings is not None + self.embeddings.append(nn.Embedding(num_steps, hidden)) + chin = hidden + hidden = min(int(chin * growth), max_channels) + self.bilstm: tp.Optional[nn.Module] + if bilstm: + self.bilstm = BLSTM(chin) + else: + self.bilstm = None + self.use_transformer = transformer + self.cross_attention = False + if transformer: + self.cross_attention = cross_attention + self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False, + cross_attention=cross_attention) + + self.use_codec = False + if codec_dim is not None: + self.conv_codec = nn.Conv1d(codec_dim, chin, 1) + self.use_codec = True + + def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None): + skips = [] + bs = x.size(0) + z = x + view_args = [1] + if type(step) is torch.Tensor: + step_tensor = step + else: + step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs) + + for idx, encoder in enumerate(self.encoders): + z = encoder(z) + if idx == 0: + z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z) + elif self.embeddings is not None: + z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z) + + skips.append(z) + + if self.use_codec: # insert condition in the bottleneck + assert condition is not None, "Model defined for conditionnal generation" + condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim + assert condition_emb.size(-1) <= 2 * z.size(-1), \ + f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}" + if not self.cross_attention: + + condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1)) + assert z.size() == condition_emb.size() + z += condition_emb + cross_attention_src = None + else: + cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C + B, T, C = cross_attention_src.shape + positions = torch.arange(T, device=x.device).view(1, -1, 1) + pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype) + cross_attention_src = cross_attention_src + pos_emb + if self.use_transformer: + z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1) + else: + if self.bilstm is None: + z = torch.zeros_like(z) + else: + 
z = self.bilstm(z) + + for decoder in self.decoders: + s = skips.pop(-1) + z = z[:, :, :s.shape[2]] + z = z + s + z = decoder(z) + + z = z[:, :, :x.shape[2]] + return Output(z) diff --git a/audiocraft/modules/__init__.py b/audiocraft/modules/__init__.py index 81ba30f6466ff91b90490a4fb92f7d3d0d00144d..6d732afcb6baa4e88f7fb075757982273e7e57e3 100644 --- a/audiocraft/modules/__init__.py +++ b/audiocraft/modules/__init__.py @@ -18,3 +18,4 @@ from .conv import ( ) from .lstm import StreamableLSTM from .seanet import SEANetEncoder, SEANetDecoder +from .transformer import StreamingTransformer \ No newline at end of file diff --git a/audiocraft/modules/chroma.py b/audiocraft/modules/chroma.py new file mode 100644 index 0000000000000000000000000000000000000000..e84fb66b4a4aaefb0b3ccac8a9a44c3b20e48f61 --- /dev/null +++ b/audiocraft/modules/chroma.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import typing as tp + +from einops import rearrange +from librosa import filters +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio + + +class ChromaExtractor(nn.Module): + """Chroma extraction and quantization. + + Args: + sample_rate (int): Sample rate for the chroma extraction. + n_chroma (int): Number of chroma bins for the chroma extraction. + radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12). + nfft (int, optional): Number of FFT. + winlen (int, optional): Window length. + winhop (int, optional): Window hop size. + argmax (bool, optional): Whether to use argmax. Defaults to False. + norm (float, optional): Norm for chroma normalization. Defaults to inf. 
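
The new `audiocraft/models/unet.py` above defines the diffusion UNet used for multi-band diffusion decoding. A tiny, CPU-friendly smoke test of the module as written; the layer sizes are arbitrary and far smaller than any real config:

```python
import torch
from audiocraft.models.unet import DiffusionUnet

# Arbitrary small sizes for illustration; real values come from the diffusion configs.
unet = DiffusionUnet(chin=1, hidden=16, depth=2, num_steps=1000)

x = torch.randn(2, 1, 16000)          # [B, C, T] noisy waveform
step = torch.randint(0, 1000, (2,))   # one diffusion step index per batch item
out = unet(x, step)
print(out.sample.shape)               # matches the input: torch.Size([2, 1, 16000])
```
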
+ """ + def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None, + winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False, + norm: float = torch.inf): + super().__init__() + self.winlen = winlen or 2 ** radix2_exp + self.nfft = nfft or self.winlen + self.winhop = winhop or (self.winlen // 4) + self.sample_rate = sample_rate + self.n_chroma = n_chroma + self.norm = norm + self.argmax = argmax + self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, + n_chroma=self.n_chroma)), persistent=False) + self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen, + hop_length=self.winhop, power=2, center=True, + pad=0, normalized=True) + + def forward(self, wav: torch.Tensor) -> torch.Tensor: + T = wav.shape[-1] + # in case we are getting a wav that was dropped out (nullified) + # from the conditioner, make sure wav length is no less that nfft + if T < self.nfft: + pad = self.nfft - T + r = 0 if pad % 2 == 0 else 1 + wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0) + assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}" + + spec = self.spec(wav).squeeze(1) + raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec) + norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6) + norm_chroma = rearrange(norm_chroma, 'b d t -> b t d') + + if self.argmax: + idx = norm_chroma.argmax(-1, keepdim=True) + norm_chroma[:] = 0 + norm_chroma.scatter_(dim=-1, index=idx, value=1) + + return norm_chroma diff --git a/audiocraft/modules/codebooks_patterns.py b/audiocraft/modules/codebooks_patterns.py index c5b35cbea8cff84aa56116dbdd860fc72a913a13..61362588403a3eef4a4b1b4ad4595526722da20f 100644 --- a/audiocraft/modules/codebooks_patterns.py +++ b/audiocraft/modules/codebooks_patterns.py @@ -122,7 +122,7 @@ class Pattern: Args: timesteps (int): Maximum number of timesteps steps to consider. keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps. - device (Union[torch.device, str]): Device for created tensors. + device (torch.device or str): Device for created tensors. Returns: indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S]. mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S]. @@ -189,9 +189,9 @@ class Pattern: keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. Steps that are beyond valid steps will be replaced by the special_token in that case. is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not. - device (Union[torch.device, str]): Device for created tensors. + device (torch.device or str): Device for created tensors. Returns: - torch.Tensor: Indexes for reconstructing the output, of shape [K, T]. + indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T]. mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. """ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout @@ -295,7 +295,7 @@ class CodebooksPatternProvider(ABC): """Builds pattern with specific interleaving between codebooks. Args: - timesteps (int): Total numer of timesteps. + timesteps (int): Total number of timesteps. 
""" raise NotImplementedError() @@ -318,7 +318,7 @@ class DelayedPatternProvider(CodebooksPatternProvider): Args: n_q (int): Number of codebooks. - delays (Optional[List[int]]): Delay for each of the codebooks. + delays (list of int, optional): Delay for each of the codebooks. If delays not defined, each codebook is delayed by 1 compared to the previous one. flatten_first (int): Flatten the first N timesteps. empty_initial (int): Prepend with N empty list of coordinates. @@ -406,10 +406,10 @@ class UnrolledPatternProvider(CodebooksPatternProvider): Args: n_q (int): Number of codebooks. - flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined, + flattening (list of int, optional): Flattening schema over the codebooks. If not defined, the codebooks will be flattened to 1 codebook per step, meaning that the sequence will have n_q extra steps for each timestep. - delays (Optional[List[int]]): Delay for each of the codebooks. If not defined, + delays (list of int, optional): Delay for each of the codebooks. If not defined, no delay is added and therefore will default to [0] * ``n_q``. Note that two codebooks that will be flattened to the same inner step should have the same delay, otherwise the pattern is considered as invalid. @@ -462,7 +462,7 @@ class UnrolledPatternProvider(CodebooksPatternProvider): """Builds pattern for delay across codebooks. Args: - timesteps (int): Total numer of timesteps. + timesteps (int): Total number of timesteps. """ # the PatternLayout is built as a tuple of sequence position and list of coordinates # so that it can be reordered properly given the required delay between codebooks of given timesteps @@ -486,13 +486,18 @@ class UnrolledPatternProvider(CodebooksPatternProvider): return Pattern(out, n_q=self.n_q, timesteps=timesteps) -class VALLEPattern(CodebooksPatternProvider): - """Almost VALL-E style pattern. We futher allow some delays for the - codebooks other than the first one. +class CoarseFirstPattern(CodebooksPatternProvider): + """First generates all the codebooks #1 (e.g. coarser), then the remaining ones, + potentially with delays. + + ..Warning:: You must always generate the full training duration at test time, for instance, + 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected + location. This is due to the non causality of the remaining codebooks with respect to + the first ones. Args: n_q (int): Number of codebooks. - delays (Optional[List[int]]): Delay for each of the codebooks. + delays (list of int, optional): Delay for each of the codebooks. If delays not defined, each codebook is delayed by 1 compared to the previous one. 
""" def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): diff --git a/audiocraft/modules/conditioners.py b/audiocraft/modules/conditioners.py index 82792316024b88d4c5c38b0a28f443627771d509..178957d1771dc4c6f2df028fd9bb60f204567955 100644 --- a/audiocraft/modules/conditioners.py +++ b/audiocraft/modules/conditioners.py @@ -10,87 +10,61 @@ from dataclasses import dataclass, field from itertools import chain import logging import math +from pathlib import Path import random import re import typing as tp import warnings -from einops import rearrange +import einops from num2words import num2words import spacy -from transformers import T5EncoderModel, T5Tokenizer # type: ignore -import torchaudio +from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore import torch from torch import nn -from torch import Tensor import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence +from .chroma import ChromaExtractor from .streaming import StreamingModule from .transformer import create_sin_embedding +from ..data.audio import audio_read from ..data.audio_dataset import SegmentInfo +from ..data.audio_utils import convert_audio +from ..environment import AudioCraftEnvironment +from ..quantization import ResidualVectorQuantizer from ..utils.autocast import TorchAutocast -from ..utils.utils import hash_trick, length_to_mask, collate +from ..utils.cache import EmbeddingCache +from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once logger = logging.getLogger(__name__) TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist) -ConditionType = tp.Tuple[Tensor, Tensor] # condition, mask +ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask class WavCondition(tp.NamedTuple): - wav: Tensor - length: Tensor + wav: torch.Tensor + length: torch.Tensor + sample_rate: tp.List[int] path: tp.List[tp.Optional[str]] = [] + seek_time: tp.List[tp.Optional[float]] = [] -def nullify_condition(condition: ConditionType, dim: int = 1): - """This function transforms an input condition to a null condition. - The way it is done by converting it to a single zero vector similarly - to how it is done inside WhiteSpaceTokenizer and NoopTokenizer. - - Args: - condition (ConditionType): a tuple of condition and mask (tp.Tuple[Tensor, Tensor]) - dim (int): the dimension that will be truncated (should be the time dimension) - WARNING!: dim should not be the batch dimension! - Returns: - ConditionType: a tuple of null condition and mask - """ - assert dim != 0, "dim cannot be the batch dimension!" - assert type(condition) == tuple and \ - type(condition[0]) == Tensor and \ - type(condition[1]) == Tensor, "'nullify_condition' got an unexpected input type!" - cond, mask = condition - B = cond.shape[0] - last_dim = cond.dim() - 1 - out = cond.transpose(dim, last_dim) - out = 0. * out[..., :1] - out = out.transpose(dim, last_dim) - mask = torch.zeros((B, 1), device=out.device).int() - assert cond.dim() == out.dim() - return out, mask - - -def nullify_wav(wav: Tensor) -> WavCondition: - """Create a nullified WavCondition from a wav tensor with appropriate shape. - - Args: - wav (Tensor): tensor of shape [B, T] - Returns: - WavCondition: wav condition with nullified wav. 
- """ - null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1) - return WavCondition( - wav=null_wav, - length=torch.tensor([0] * wav.shape[0], device=wav.device), - path=['null_wav'] * wav.shape[0] - ) +class JointEmbedCondition(tp.NamedTuple): + wav: torch.Tensor + text: tp.List[tp.Optional[str]] + length: torch.Tensor + sample_rate: tp.List[int] + path: tp.List[tp.Optional[str]] = [] + seek_time: tp.List[tp.Optional[float]] = [] @dataclass class ConditioningAttributes: text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict) wav: tp.Dict[str, WavCondition] = field(default_factory=dict) + joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict) def __getitem__(self, item): return getattr(self, item) @@ -103,14 +77,23 @@ class ConditioningAttributes: def wav_attributes(self): return self.wav.keys() + @property + def joint_embed_attributes(self): + return self.joint_embed.keys() + @property def attributes(self): - return {"text": self.text_attributes, "wav": self.wav_attributes} + return { + "text": self.text_attributes, + "wav": self.wav_attributes, + "joint_embed": self.joint_embed_attributes, + } def to_flat_dict(self): return { **{f"text.{k}": v for k, v in self.text.items()}, **{f"wav.{k}": v for k, v in self.wav.items()}, + **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()} } @classmethod @@ -131,11 +114,74 @@ class SegmentWithAttributes(SegmentInfo): raise NotImplementedError() +def nullify_condition(condition: ConditionType, dim: int = 1): + """Transform an input condition to a null condition. + The way it is done by converting it to a single zero vector similarly + to how it is done inside WhiteSpaceTokenizer and NoopTokenizer. + + Args: + condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor]) + dim (int): The dimension that will be truncated (should be the time dimension) + WARNING!: dim should not be the batch dimension! + Returns: + ConditionType: A tuple of null condition and mask + """ + assert dim != 0, "dim cannot be the batch dimension!" + assert isinstance(condition, tuple) and \ + isinstance(condition[0], torch.Tensor) and \ + isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!" + cond, mask = condition + B = cond.shape[0] + last_dim = cond.dim() - 1 + out = cond.transpose(dim, last_dim) + out = 0. * out[..., :1] + out = out.transpose(dim, last_dim) + mask = torch.zeros((B, 1), device=out.device).int() + assert cond.dim() == out.dim() + return out, mask + + +def nullify_wav(cond: WavCondition) -> WavCondition: + """Transform a WavCondition to a nullified WavCondition. + It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes. + + Args: + cond (WavCondition): Wav condition with wav, tensor of shape [B, T]. + Returns: + WavCondition: Nullified wav condition. + """ + null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1) + return WavCondition( + wav=null_wav, + length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device), + sample_rate=cond.sample_rate, + path=[None] * cond.wav.shape[0], + seek_time=[None] * cond.wav.shape[0], + ) + + +def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition: + """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0, + and replacing metadata by dummy attributes. 
+ + Args: + cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T]. + """ + null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1) + return JointEmbedCondition( + wav=null_wav, text=[None] * len(embed.text), + length=torch.LongTensor([0]).to(embed.wav.device), + sample_rate=embed.sample_rate, + path=[None] * embed.wav.shape[0], + seek_time=[0] * embed.wav.shape[0], + ) + + class Tokenizer: - """Base class for all tokenizers + """Base tokenizer implementation (in case we want to introduce more advances tokenizers in the future). """ - def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]: + def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError() @@ -146,7 +192,7 @@ class WhiteSpaceTokenizer(Tokenizer): [[78, 62, 31, 4, 78, 25, 19, 34], [59, 77, 0, 0, 0, 0, 0, 0]] """ - PUNCTUATIONS = "?:!.,;" + PUNCTUATION = "?:!.,;" def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm", lemma: bool = True, stopwords: bool = True) -> None: @@ -161,18 +207,15 @@ class WhiteSpaceTokenizer(Tokenizer): self.nlp = spacy.load(language) @tp.no_type_check - def __call__( - self, - texts: tp.List[tp.Optional[str]], - return_text: bool = False - ) -> tp.Tuple[Tensor, Tensor]: + def __call__(self, texts: tp.List[tp.Optional[str]], + return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]: """Take a list of strings and convert them to a tensor of indices. Args: - texts (tp.List[str]): List of strings. + texts (list[str]): List of strings. return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False. Returns: - tp.Tuple[Tensor, Tensor]: + tuple[torch.Tensor, torch.Tensor]: - Indices of words in the LUT. - And a mask indicating where the padding tokens are """ @@ -181,7 +224,7 @@ class WhiteSpaceTokenizer(Tokenizer): for i, text in enumerate(texts): # if current sample doesn't have a certain attribute, replace with pad token if text is None: - output.append(Tensor([self.pad_idx])) + output.append(torch.Tensor([self.pad_idx])) lengths.append(0) continue @@ -192,15 +235,15 @@ class WhiteSpaceTokenizer(Tokenizer): # remove stopwords if self.stopwords: text = [w for w in text if not w.is_stop] # type: ignore - # remove punctuations - text = [w for w in text if w.text not in self.PUNCTUATIONS] # type: ignore + # remove punctuation + text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore # lemmatize if needed text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore texts[i] = " ".join(text) lengths.append(len(text)) # convert to tensor - tokens = Tensor([hash_trick(w, self.n_bins) for w in text]) + tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text]) output.append(tokens) mask = length_to_mask(torch.IntTensor(lengths)).int() @@ -224,7 +267,7 @@ class NoopTokenizer(Tokenizer): self.n_bins = n_bins self.pad_idx = pad_idx - def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]: + def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: output, lengths = [], [] for text in texts: # if current sample doesn't have a certain attribute, replace with pad token @@ -241,15 +284,16 @@ class NoopTokenizer(Tokenizer): class BaseConditioner(nn.Module): - """Base model for all conditioner modules. 
We allow the output dim to be different - than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large; + """Base model for all conditioner modules. + We allow the output dim to be different than the hidden dim for two reasons: + 1) keep our LUTs small when the vocab is large; 2) make all condition dims consistent. Args: - dim (int): Hidden dim of the model (text-encoder/LUT). + dim (int): Hidden dim of the model. output_dim (int): Output dim of the conditioner. """ - def __init__(self, dim, output_dim): + def __init__(self, dim: int, output_dim: int): super().__init__() self.dim = dim self.output_dim = output_dim @@ -294,9 +338,9 @@ class LUTConditioner(TextConditioner): super().__init__(dim, output_dim) self.embed = nn.Embedding(n_bins, dim) self.tokenizer: Tokenizer - if tokenizer == "whitespace": + if tokenizer == 'whitespace': self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx) - elif tokenizer == "noop": + elif tokenizer == 'noop': self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx) else: raise ValueError(f"unrecognized tokenizer `{tokenizer}`.") @@ -346,13 +390,12 @@ class T5Conditioner(TextConditioner): def __init__(self, name: str, output_dim: int, finetune: bool, device: str, autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0., normalize_text: bool = False): - assert name in self.MODELS, f"unrecognized t5 model name (should in {self.MODELS})" + assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})" super().__init__(self.MODELS_DIMS[name], output_dim) self.device = device self.name = name self.finetune = finetune self.word_dropout = word_dropout - if autocast_dtype is None or self.device == 'cpu': self.autocast = TorchAutocast(enabled=False) if self.device != 'cpu': @@ -378,7 +421,7 @@ class T5Conditioner(TextConditioner): else: # this makes sure that the t5 models is not part # of the saved checkpoint - self.__dict__["t5"] = t5.to(device) + self.__dict__['t5'] = t5.to(device) self.normalize_text = normalize_text if normalize_text: @@ -398,13 +441,13 @@ class T5Conditioner(TextConditioner): empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""]) - inputs = self.t5_tokenizer(entries, return_tensors="pt", padding=True).to(self.device) - mask = inputs["attention_mask"] + inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device) + mask = inputs['attention_mask'] mask[empty_idx, :] = 0 # zero-out index where the input is non-existant return inputs def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType: - mask = inputs["attention_mask"] + mask = inputs['attention_mask'] with torch.set_grad_enabled(self.finetune), self.autocast: embeds = self.t5(**inputs).last_hidden_state embeds = self.output_proj(embeds.to(self.output_proj.weight)) @@ -426,204 +469,558 @@ class WaveformConditioner(BaseConditioner): def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]): super().__init__(dim, output_dim) self.device = device + # if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample. 
+ self._use_masking = True - def tokenize(self, wav_length: WavCondition) -> WavCondition: - wav, length, path = wav_length + def tokenize(self, x: WavCondition) -> WavCondition: + wav, length, sample_rate, path, seek_time = x assert length is not None - return WavCondition(wav.to(self.device), length.to(self.device), path) + return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time) - def _get_wav_embedding(self, wav: Tensor) -> Tensor: - """Gets as input a wav and returns a dense vector of conditions.""" + def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: + """Gets as input a WavCondition and returns a dense embedding.""" raise NotImplementedError() def _downsampling_factor(self): """Returns the downsampling factor of the embedding model.""" raise NotImplementedError() - def forward(self, inputs: WavCondition) -> ConditionType: - """ + def forward(self, x: WavCondition) -> ConditionType: + """Extract condition embedding and mask from a waveform and its metadata. Args: - input (WavCondition): Tuple of (waveform, lengths). + x (WavCondition): Waveform condition containing raw waveform and metadata. Returns: - ConditionType: Dense vector representing the conditioning along with its' mask. + ConditionType: a dense vector representing the conditioning along with its mask """ - wav, lengths, path = inputs + wav, lengths, *_ = x with torch.no_grad(): - embeds = self._get_wav_embedding(wav) + embeds = self._get_wav_embedding(x) embeds = embeds.to(self.output_proj.weight) embeds = self.output_proj(embeds) - if lengths is not None: + if lengths is not None and self._use_masking: lengths = lengths / self._downsampling_factor() mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore else: - mask = torch.ones_like(embeds) - embeds = (embeds * mask.unsqueeze(2).to(self.device)) - + mask = torch.ones_like(embeds[..., 0]) + embeds = (embeds * mask.unsqueeze(-1)) return embeds, mask class ChromaStemConditioner(WaveformConditioner): - """Chroma conditioner that uses DEMUCS to first filter out drums and bass. The is followed by - the insight the drums and bass often dominate the chroma, leading to the chroma not containing the - information about melody. + """Chroma conditioner based on stems. + The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as + the drums and bass often dominate the chroma leading to the chroma features + not containing information about the melody. Args: output_dim (int): Output dimension for the conditioner. sample_rate (int): Sample rate for the chroma extractor. - n_chroma (int): Number of chroma for the chroma extractor. - radix2_exp (int): Radix2 exponent for the chroma extractor. - duration (float): Duration used during training. This is later used for correct padding + n_chroma (int): Number of chroma bins for the chroma extractor. + radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12). + duration (int): duration used during training. This is later used for correct padding in case we are using chroma as prefix. - match_len_on_eval (bool, optional): If True then all chromas are padded to the training + match_len_on_eval (bool, optional): if True then all chromas are padded to the training duration. Defaults to False. 
- eval_wavs (str, optional): Path to a json egg with waveform, this waveforms are used as + eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as conditions during eval (for cases where we don't want to leak test conditions like MusicCaps). Defaults to None. - n_eval_wavs (int, optional): Limits the number of waveforms used for conditioning. Defaults to 0. + n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0. device (tp.Union[torch.device, str], optional): Device for the conditioner. **kwargs: Additional parameters for the chroma extractor. """ def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int, duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None, - n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs): + n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None, + device: tp.Union[torch.device, str] = 'cpu', **kwargs): from demucs import pretrained super().__init__(dim=n_chroma, output_dim=output_dim, device=device) - self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32) + self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32) self.sample_rate = sample_rate self.match_len_on_eval = match_len_on_eval + if match_len_on_eval: + self._use_masking = False self.duration = duration - self.__dict__["demucs"] = pretrained.get_model('htdemucs').to(device) - self.stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3} - self.stem_idx = torch.LongTensor([self.stem2idx['vocal'], self.stem2idx['other']]).to(device) - self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp, - device=device, **kwargs) + self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device) + stem_sources: list = self.demucs.sources # type: ignore + self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device) + self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, + radix2_exp=radix2_exp, **kwargs).to(device) self.chroma_len = self._get_chroma_len() - - def _downsampling_factor(self): + self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs) + self.cache = None + if cache_path is not None: + self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, + compute_embed_fn=self._get_full_chroma_for_cache, + extract_embed_fn=self._extract_chroma_chunk) + + def _downsampling_factor(self) -> int: return self.chroma.winhop - def _get_chroma_len(self): - """Get length of chroma during training""" - dummy_wav = torch.zeros((1, self.sample_rate * self.duration), device=self.device) + def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]: + """Load pre-defined waveforms from a json. + These waveforms will be used for chroma extraction during evaluation. + This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps). 
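
Putting the reworked `ChromaStemConditioner` pieces together, a hedged end-to-end sketch: it assumes `demucs` is installed (the conditioner pulls the pretrained `htdemucs` model, so this is slow on CPU), and the output dimension and durations are illustrative only:

```python
import torch
from audiocraft.modules.conditioners import ChromaStemConditioner, WavCondition

conditioner = ChromaStemConditioner(
    output_dim=512, sample_rate=32000, n_chroma=12, radix2_exp=12,
    duration=30.0, device="cpu")

melody = WavCondition(
    wav=torch.randn(1, 1, 32000 * 5), length=torch.tensor([32000 * 5]),
    sample_rate=[32000], path=[None], seek_time=[0.0])

embeds, mask = conditioner(conditioner.tokenize(melody))
print(embeds.shape)   # [B, frames, output_dim]
```
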
+ """ + if path is None: + return None + + logger.info(f"Loading evaluation wavs from {path}") + from audiocraft.data.audio_dataset import AudioDataset + dataset: AudioDataset = AudioDataset.from_meta( + path, segment_duration=self.duration, min_audio_duration=self.duration, + sample_rate=self.sample_rate, channels=1) + + if len(dataset) > 0: + eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device) + logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner") + return eval_wavs + else: + raise ValueError("Could not find evaluation wavs, check lengths of wavs") + + def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None: + self.eval_wavs = eval_wavs + + def has_eval_wavs(self) -> bool: + return self.eval_wavs is not None + + def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor: + """Sample wavs from a predefined list.""" + assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided." + total_eval_wavs = len(self.eval_wavs) + out = self.eval_wavs + if num_samples > total_eval_wavs: + out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1) + return out[torch.randperm(len(out))][:num_samples] + + def _get_chroma_len(self) -> int: + """Get length of chroma during training.""" + dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device) dummy_chr = self.chroma(dummy_wav) return dummy_chr.shape[1] @torch.no_grad() - def _get_filtered_wav(self, wav): + def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: + """Get parts of the wav that holds the melody, extracting the main stems from the wav.""" from demucs.apply import apply_model from demucs.audio import convert_audio with self.autocast: - wav = convert_audio(wav, self.sample_rate, self.demucs.samplerate, self.demucs.audio_channels) + wav = convert_audio( + wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore stems = apply_model(self.demucs, wav, device=self.device) - stems = stems[:, self.stem_idx] # extract stem - stems = stems.sum(1) # merge extracted stems - stems = stems.mean(1, keepdim=True) # mono - stems = convert_audio(stems, self.demucs.samplerate, self.sample_rate, 1) - return stems + stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning + mix_wav = stems.sum(1) # merge extracted stems to single waveform + mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore + return mix_wav @torch.no_grad() - def _get_wav_embedding(self, wav): + def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor: + """Extract chroma features from the waveform.""" + with self.autocast: + return self.chroma(wav) + + @torch.no_grad() + def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: + """Compute wav embedding, applying stem and chroma extraction.""" # avoid 0-size tensors when we are working with null conds if wav.shape[-1] == 1: - return self.chroma(wav) - stems = self._get_filtered_wav(wav) - chroma = self.chroma(stems) + return self._extract_chroma(wav) + stems = self._get_stemmed_wav(wav, sample_rate) + chroma = self._extract_chroma(stems) + return chroma + + @torch.no_grad() + def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor: + """Extract chroma from the whole audio waveform at the given path.""" + wav, sr = audio_read(path) + wav = wav[None].to(self.device) + wav = convert_audio(wav, 
sr, self.sample_rate, to_channels=1) + chroma = self._compute_wav_embedding(wav, self.sample_rate)[0] + return chroma + + def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor: + """Extract a chunk of chroma from the full chroma derived from the full waveform.""" + wav_length = x.wav.shape[-1] + seek_time = x.seek_time[idx] + assert seek_time is not None, ( + "WavCondition seek_time is required " + "when extracting chroma chunks from pre-computed chroma.") + full_chroma = full_chroma.float() + frame_rate = self.sample_rate / self._downsampling_factor() + target_length = int(frame_rate * wav_length / self.sample_rate) + index = int(frame_rate * seek_time) + out = full_chroma[index: index + target_length] + out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0] + return out.to(self.device) + + @torch.no_grad() + def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: + """Get the wav embedding from the WavCondition. + The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly + or will rely on the embedding cache to load the pre-computed embedding if relevant. + """ + sampled_wav: tp.Optional[torch.Tensor] = None + if not self.training and self.eval_wavs is not None: + warn_once(logger, "Using precomputed evaluation wavs!") + sampled_wav = self._sample_eval_wavs(len(x.wav)) + + no_undefined_paths = all(p is not None for p in x.path) + no_nullified_cond = x.wav.shape[-1] > 1 + if sampled_wav is not None: + chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate) + elif self.cache is not None and no_undefined_paths and no_nullified_cond: + paths = [Path(p) for p in x.path if p is not None] + chroma = self.cache.get_embed_from_cache(paths, x) + else: + assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal." + chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0]) if self.match_len_on_eval: - b, t, c = chroma.shape - if t > self.chroma_len: + B, T, C = chroma.shape + if T > self.chroma_len: chroma = chroma[:, :self.chroma_len] - logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})') - elif t < self.chroma_len: - # chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t)) - n_repeat = int(math.ceil(self.chroma_len / t)) + logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})") + elif T < self.chroma_len: + n_repeat = int(math.ceil(self.chroma_len / T)) chroma = chroma.repeat(1, n_repeat, 1) chroma = chroma[:, :self.chroma_len] - logger.debug(f'chroma was zero-padded! ({t} -> {chroma.shape[1]})') + logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})") + return chroma + def tokenize(self, x: WavCondition) -> WavCondition: + """Apply WavConditioner tokenization and populate cache if needed.""" + x = super().tokenize(x) + no_undefined_paths = all(p is not None for p in x.path) + if self.cache is not None and no_undefined_paths: + paths = [Path(p) for p in x.path if p is not None] + self.cache.populate_embed_cache(paths, x) + return x -class ChromaExtractor(nn.Module): - """Chroma extraction class, handles chroma extraction and quantization. + +class JointEmbeddingConditioner(BaseConditioner): + """Joint embedding conditioning supporting both audio or text conditioning. Args: - sample_rate (int): Sample rate. - n_chroma (int): Number of chroma to consider. - radix2_exp (int): Radix2 exponent. - nfft (tp.Optional[int], optional): Number of FFT. 
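With `match_len_on_eval` enabled, `_get_wav_embedding` forces eval-time chroma to the training length by truncating sequences that are too long and tiling (repeating) ones that are too short. A standalone sketch of that length-matching step, with no audiocraft dependency:

```python
import math

import torch


def match_chroma_len(chroma: torch.Tensor, target_len: int) -> torch.Tensor:
    """Truncate or tile a [B, T, C] chroma tensor along time to exactly target_len frames."""
    B, T, C = chroma.shape
    if T > target_len:
        return chroma[:, :target_len]
    if T < target_len:
        n_repeat = math.ceil(target_len / T)
        chroma = chroma.repeat(1, n_repeat, 1)   # tile along the time axis
        return chroma[:, :target_len]
    return chroma


chroma = torch.randn(2, 100, 12)
print(match_chroma_len(chroma, 235).shape)  # torch.Size([2, 235, 12])
```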
- winlen (tp.Optional[int], optional): Window length. - winhop (tp.Optional[int], optional): Window hop size. - argmax (bool, optional): Whether to use argmax. Defaults to False. - norm (float, optional): Norm for chroma normalization. Defaults to inf. - device (tp.Union[torch.device, str], optional): Device to use. Defaults to cpu. + dim (int): Dimension. + output_dim (int): Output dimension. + device (str): Device. + attribute (str): Attribute used by the conditioner. + autocast_dtype (str): Autocast for the conditioner. + quantize (bool): Whether to quantize the CLAP embedding. + n_q (int): Number of residual quantizers (used if quantize is true). + bins (int): Quantizers' codebooks size (used if quantize is true). + kwargs: Additional parameters for residual vector quantizer. """ - def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, - nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, - argmax: bool = False, norm: float = torch.inf, device: tp.Union[torch.device, str] = "cpu"): - super().__init__() - from librosa import filters + def __init__(self, dim: int, output_dim: int, device: str, attribute: str, + autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True, + n_q: int = 12, bins: int = 1024, **kwargs): + super().__init__(dim=dim, output_dim=output_dim) self.device = device - self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32) - self.winlen = winlen or 2 ** radix2_exp - self.nfft = nfft or self.winlen - self.winhop = winhop or (self.winlen // 4) - self.sr = sample_rate - self.n_chroma = n_chroma - self.norm = norm - self.argmax = argmax - self.window = torch.hann_window(self.winlen).to(device) - self.fbanks = torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, - n_chroma=self.n_chroma)).to(device) - self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen, - hop_length=self.winhop, power=2, center=True, - pad=0, normalized=True).to(device) - - def forward(self, wav): + self.attribute = attribute + if autocast_dtype is None or device == 'cpu': + self.autocast = TorchAutocast(enabled=False) + logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.") + else: + dtype = getattr(torch, autocast_dtype) + assert isinstance(dtype, torch.dtype) + logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.") + self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) + # residual vector quantizer to discretize the conditioned embedding + self.quantizer: tp.Optional[ResidualVectorQuantizer] = None + if quantize: + self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs) + + def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Get joint embedding in latent space from the inputs. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding + and corresponding empty indexes. 
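For reference, the `ChromaExtractor` being removed above reduces to three steps: a normalized magnitude spectrogram, a projection onto a librosa chroma filterbank, and per-frame normalization. A condensed sketch of that pipeline, reusing the same `torchaudio`/`librosa` calls as the removed code (shapes are illustrative only):

```python
import torch
import torch.nn.functional as F
import torchaudio
from librosa import filters


def extract_chroma(wav: torch.Tensor, sample_rate: int, n_chroma: int = 12,
                   radix2_exp: int = 12) -> torch.Tensor:
    """Return [B, T, n_chroma] chroma features for a [B, T_samples] waveform batch."""
    winlen = 2 ** radix2_exp
    nfft, winhop = winlen, winlen // 4
    spec = torchaudio.transforms.Spectrogram(
        n_fft=nfft, win_length=winlen, hop_length=winhop,
        power=2, center=True, pad=0, normalized=True)(wav)              # [B, F, T]
    fbank = torch.from_numpy(
        filters.chroma(sr=sample_rate, n_fft=nfft, tuning=0, n_chroma=n_chroma)).float()
    chroma = torch.einsum('cf,...ft->...ct', fbank, spec)               # project onto pitch classes
    chroma = F.normalize(chroma, p=torch.inf, dim=-2, eps=1e-6)         # per-frame max-norm
    return chroma.transpose(-1, -2)                                     # [B, T, n_chroma]


print(extract_chroma(torch.randn(1, 32000), sample_rate=32000).shape)   # torch.Size([1, 32, 12])
```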
+ """ + raise NotImplementedError() + + def forward(self, x: JointEmbedCondition) -> ConditionType: with self.autocast: - T = wav.shape[-1] - # in case we are getting a wav that was dropped out (nullified) - # make sure wav length is no less that nfft - if T < self.nfft: - pad = self.nfft - T - r = 0 if pad % 2 == 0 else 1 - wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0) - assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}' - spec = self.spec(wav).squeeze(1) - raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec) - norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6) - norm_chroma = rearrange(norm_chroma, "b d t -> b t d") - - if self.argmax: - idx = norm_chroma.argmax(-1, keepdims=True) - norm_chroma[:] = 0 - norm_chroma.scatter_(dim=-1, index=idx, value=1) - - return norm_chroma - - -def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str): + embed, empty_idx = self._get_embed(x) + if self.quantizer is not None: + embed = embed.view(-1, self.dim, 1) + q_res = self.quantizer(embed, frame_rate=1) + out_embed = q_res.x.view(-1, self.dim) + else: + out_embed = embed + out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim) + mask = torch.ones(*out_embed.shape[:2], device=out_embed.device) + mask[empty_idx, :] = 0 # zero-out index where the input is non-existant + out_embed = (out_embed * mask.unsqueeze(-1)) + return out_embed, mask + + def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: + return x + + +class CLAPEmbeddingConditioner(JointEmbeddingConditioner): + """Joint Embedding conditioner based on pre-trained CLAP model. + + This CLAP-based conditioner supports a caching mechanism + over the computed embeddings for faster training. + + Args: + dim (int): Dimension. + output_dim (int): Output dimension. + device (str): Device. + attribute (str): Attribute used by the conditioner. + quantize (bool): Whether to quantize the CLAP embedding. + n_q (int): Number of residual quantizers (used if quantize is true). + bins (int): Quantizers' codebooks size (used if quantize is true). + checkpoint (str): Path to CLAP checkpoint. + model_arch (str): CLAP model architecture. + enable_fusion (bool): Enable fusion for CLAP model. + sample_rate (int): Sample rate used by CLAP model. + max_audio_length (float): Maximum audio length for CLAP model. + audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence. + normalize (bool): Whether to normalize the CLAP embedding. + text_p (float): Probability of using text representation instead of audio at train time. + batch_size (Optional[int]): Batch size for CLAP embedding computation. + autocast_dtype (str): Autocast for the conditioner. + cache_path (Optional[str]): Path for pre-computed embeddings caching. + kwargs: Additional parameters for residual vector quantizer. 
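In `JointEmbeddingConditioner.forward`, items whose condition is absent are listed in `empty_idx` and their projected embedding is zeroed through a `[B, 1]` mask, so the model sees an all-zero vector rather than a stale embedding. A minimal standalone sketch of that masking step:

```python
import torch


def mask_empty(out_embed: torch.Tensor, empty_idx: torch.Tensor):
    """Zero-out rows of a [B, 1, D] embedding whose batch index is listed in empty_idx.

    Returns the masked embedding and the [B, 1] mask (1 = condition present, 0 = absent).
    """
    mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
    mask[empty_idx, :] = 0
    return out_embed * mask.unsqueeze(-1), mask


embed = torch.randn(4, 1, 512)
masked, mask = mask_empty(embed, torch.LongTensor([1, 3]))
print(mask.squeeze(-1))        # tensor([1., 0., 1., 0.])
print(masked[1].abs().sum())   # tensor(0.)
```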
+ """ + def __init__(self, dim: int, output_dim: int, device: str, attribute: str, + quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str, + enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int, + normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None, + autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs): + try: + import laion_clap # type: ignore + except ImportError: + raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'") + warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). " + "Please retrain all models.") + checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint) + clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base') + clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) + load_clap_state_dict(clap_model, checkpoint) + clap_model.eval() + clap_model.to(device) + super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute, + autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins, + **kwargs) + self.checkpoint = checkpoint + self.enable_fusion = enable_fusion + self.model_arch = model_arch + self.clap: laion_clap.CLAP_Module + self.clap_tokenize: RobertaTokenizer + self.clap_sample_rate = sample_rate + self.clap_max_frames = int(self.clap_sample_rate * max_audio_length) + self.clap_stride = int(self.clap_sample_rate * audio_stride) + self.batch_size = batch_size or 1 + self.normalize = normalize + self.text_p = text_p + self.__dict__['clap_tokenize'] = clap_tokenize + self.__dict__['clap'] = clap_model + self.wav_cache, self.text_cache = None, None + if cache_path is not None: + self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, + compute_embed_fn=self._get_wav_embedding_for_cache, + extract_embed_fn=self._extract_wav_embedding_chunk) + self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device, + compute_embed_fn=self._get_text_embedding_for_cache) + + def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: + # we use the default params from CLAP module here as well + return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt") + + def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor: + """Compute text embedding from CLAP model on a given a batch of text. + + Args: + text (list[str]): List of text for the batch, with B items. + Returns: + torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension. + """ + with torch.no_grad(): + embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True) + return embed.view(embed.size(0), 1, embed.size(-1)) + + def _get_text_embedding_for_cache(self, path: tp.Union[Path, str], + x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Get text embedding function for the cache.""" + text = x.text[idx] + text = text if text is not None else "" + return self._compute_text_embedding([text])[0] + + def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor: + """Preprocess wav to expected format by CLAP model. + + Args: + wav (torch.Tensor): Audio wav, of shape [B, C, T]. + length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. 
+ sample_rates (list[int]): Sample rates for each sample in the batch + Returns: + torch.Tensor: Audio wav of shape [B, T]. + """ + assert wav.dim() == 3, "Expecting wav to be [B, C, T]" + if sample_rates is not None: + _wav = [] + for i, audio in enumerate(wav): + sr = sample_rates[i] + audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1) + _wav.append(audio) + wav = torch.stack(_wav, dim=0) + wav = wav.mean(dim=1) + return wav + + def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor, + sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor: + """Compute audio wave embedding from CLAP model. + + Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences, + we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and + average the resulting embeddings. + + Args: + wav (torch.Tensor): Audio wav, of shape [B, C, T]. + length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. + sample_rates (list[int]): Sample rates for each sample in the batch. + reduce_mean (bool): Whether to get the average tensor. + Returns: + torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension. + """ + with torch.no_grad(): + wav = self._preprocess_wav(wav, length, sample_rates) + B, T = wav.shape + if T >= self.clap_max_frames: + wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T] + else: + wav = wav.view(-1, 1, T) # [B, F, T] with F=1 + wav = einops.rearrange(wav, 'b f t -> (b f) t') + embed_list = [] + for i in range(0, wav.size(0), self.batch_size): + _wav = wav[i:i+self.batch_size, ...] + _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True) + embed_list.append(_embed) + embed = torch.cat(embed_list, dim=0) + embed = einops.rearrange(embed, '(b f) d -> b f d', b=B) + if reduce_mean: + embed = embed.mean(dim=1, keepdim=True) + return embed # [B, F, D] with F=1 if reduce_mean is True + + def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path], + x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Compute audio wave embedding for the cache. + The embedding is computed on a given audio read from file. + + Args: + path (str or Path): Path to the full audio file. + Returns: + torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension. + """ + wav, sr = audio_read(path) # [C, T] + wav = wav.unsqueeze(0).to(self.device) # [1, C, T] + wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device) + embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D] + return embed.squeeze(0) # [F, D] + + def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding. + + Args: + full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D]. + x (JointEmbedCondition): Joint embedding condition for the full batch. + idx (int): Index considered for the given embedding to extract. + Returns: + torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D]. + """ + sample_rate = x.sample_rate[idx] + seek_time = x.seek_time[idx] + seek_time = 0. 
if seek_time is None else seek_time + clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate + end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate + start_offset = int(seek_time * sample_rate // clap_stride) + end_offset = int(end_seek_time * sample_rate // clap_stride) + wav_embed = full_embed[start_offset:end_offset, ...] + wav_embed = wav_embed.mean(dim=0, keepdim=True) + return wav_embed.to(self.device) # [F, D] + + def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor: + """Get CLAP embedding from a batch of text descriptions.""" + no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout + if self.text_cache is not None and no_nullified_cond: + assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + embed = self.text_cache.get_embed_from_cache(paths, x) + else: + text = [xi if xi is not None else "" for xi in x.text] + embed = self._compute_text_embedding(text) + if self.normalize: + embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) + return embed + + def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor: + """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates).""" + no_undefined_paths = all(p is not None for p in x.path) + no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout + if self.wav_cache is not None and no_undefined_paths and no_nullified_cond: + paths = [Path(p) for p in x.path if p is not None] + embed = self.wav_cache.get_embed_from_cache(paths, x) + else: + embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True) + if self.normalize: + embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) + return embed + + def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: + # Trying to limit as much as possible sync points when the cache is warm. + no_undefined_paths = all(p is not None for p in x.path) + if self.wav_cache is not None and no_undefined_paths: + assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + self.wav_cache.populate_embed_cache(paths, x) + if self.text_cache is not None and no_undefined_paths: + assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + self.text_cache.populate_embed_cache(paths, x) + return x + + def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Extract shared latent representation from either the wav or the text using CLAP.""" + # decide whether to use text embedding at train time or not + use_text_embed = random.random() < self.text_p + if self.training and not use_text_embed: + embed = self._get_wav_embedding(x) + empty_idx = torch.LongTensor([]) # we assume we always have the audio wav + else: + embed = self._get_text_embedding(x) + empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""]) + return embed, empty_idx + + +def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes: """Utility function for nullifying an attribute inside an ConditioningAttributes object. - If the condition is of type "wav", then nullify it using "nullify_condition". 
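`_get_embed` in the CLAP conditioner picks the text branch with probability `text_p` during training (and always at eval), recording in `empty_idx` which text entries were missing. A small sketch of that selection logic; the two embedding callables here are stand-ins, not the real CLAP calls:

```python
import random
import typing as tp

import torch


def pick_embedding(texts: tp.List[tp.Optional[str]],
                   get_text_embed, get_wav_embed,
                   text_p: float, training: bool):
    """Return (embedding, empty_idx), choosing the text or audio branch like the CLAP conditioner."""
    use_text = random.random() < text_p
    if training and not use_text:
        embed = get_wav_embed()
        empty_idx = torch.LongTensor([])            # the audio wav is assumed to always be present
    else:
        embed = get_text_embed()
        empty_idx = torch.LongTensor(
            [i for i, t in enumerate(texts) if t is None or t == ""])
    return embed, empty_idx


texts = ["a calm piano piece", "", None, "upbeat synth pop"]
_, empty_idx = pick_embedding(texts, lambda: torch.zeros(4, 512), lambda: torch.zeros(4, 512),
                              text_p=1.0, training=True)
print(empty_idx)   # tensor([1, 2])
```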
- If the condition is of any other type, set its' value to None. + If the condition is of type "wav", then nullify it using `nullify_condition` function. + If the condition is of any other type, set its value to None. Works in-place. """ - if condition_type not in ["text", "wav"]: + if condition_type not in ['text', 'wav', 'joint_embed']: raise ValueError( "dropout_condition got an unexpected condition type!" - f" expected 'wav' or 'text' but got '{condition_type}'" + f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'" ) if condition not in getattr(sample, condition_type): raise ValueError( "dropout_condition received an unexpected condition!" f" expected wav={sample.wav.keys()} and text={sample.text.keys()}" - f"but got '{condition}' of type '{condition_type}'!" + f" but got '{condition}' of type '{condition_type}'!" ) - if condition_type == "wav": - wav, length, path = sample.wav[condition] - sample.wav[condition] = nullify_wav(wav) + if condition_type == 'wav': + wav_cond = sample.wav[condition] + sample.wav[condition] = nullify_wav(wav_cond) + elif condition_type == 'joint_embed': + embed = sample.joint_embed[condition] + sample.joint_embed[condition] = nullify_joint_embed(embed) else: sample.text[condition] = None @@ -631,7 +1028,7 @@ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condi class DropoutModule(nn.Module): - """Base class for all dropout modules.""" + """Base module for all dropout modules.""" def __init__(self, seed: int = 1234): super().__init__() self.rng = torch.Generator() @@ -639,10 +1036,11 @@ class DropoutModule(nn.Module): class AttributeDropout(DropoutModule): - """Applies dropout with a given probability per attribute. This is different from the behavior of - ClassifierFreeGuidanceDropout as this allows for attributes to be dropped out separately. For example, - "artist" can be dropped while "genre" remains. This is in contrast to ClassifierFreeGuidanceDropout - where if "artist" is dropped "genre" must also be dropped. + """Dropout with a given probability per attribute. + This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes + to be dropped out separately. For example, "artist" can be dropped while "genre" remains. + This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre" + must also be dropped. Args: p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example: @@ -665,21 +1063,19 @@ class AttributeDropout(DropoutModule): def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: """ Args: - samples (tp.List[ConditioningAttributes]): List of conditions. + samples (list[ConditioningAttributes]): List of conditions. Returns: - tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None. + list[ConditioningAttributes]: List of conditions after certain attributes were set to None. 
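The practical difference between the two dropout modules in this file: `AttributeDropout` flips an independent coin per attribute (so "artist" can be dropped while "genre" survives), while `ClassifierFreeGuidanceDropout` nullifies every attribute of every sample at once. A toy sketch over plain dicts, with `None` standing in for the `dropout_condition` nullification:

```python
import torch


def attribute_dropout(samples, p_per_attr, rng):
    """Independently drop each attribute across the whole batch, each with its own probability."""
    for attr, p in p_per_attr.items():
        if torch.rand(1, generator=rng).item() < p:
            for s in samples:
                s[attr] = None                      # stand-in for dropout_condition(...)
    return samples


def cfg_dropout(samples, p, rng):
    """Drop all attributes of all samples, or none at all."""
    if torch.rand(1, generator=rng).item() >= p:
        return samples
    return [{attr: None for attr in s} for s in samples]


rng = torch.Generator().manual_seed(1234)
batch = [{"artist": "A", "genre": "rock"}, {"artist": "B", "genre": "jazz"}]
print(attribute_dropout([dict(s) for s in batch], {"artist": 0.5, "genre": 0.1}, rng))
print(cfg_dropout(batch, p=0.3, rng=rng))
```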
""" if not self.training and not self.active_on_eval: return samples samples = deepcopy(samples) - for condition_type, ps in self.p.items(): # for condition types [text, wav] for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre]) if torch.rand(1, generator=self.rng).item() < p: for sample in samples: dropout_condition(sample, condition_type, condition) - return samples def __repr__(self): @@ -687,8 +1083,8 @@ class AttributeDropout(DropoutModule): class ClassifierFreeGuidanceDropout(DropoutModule): - """Applies Classifier Free Guidance dropout, meaning all attributes - are dropped with the same probability. + """Classifier Free Guidance dropout. + All attributes are dropped with the same probability. Args: p (float): Probability to apply condition dropout during training. @@ -701,9 +1097,9 @@ class ClassifierFreeGuidanceDropout(DropoutModule): def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: """ Args: - samples (tp.List[ConditioningAttributes]): List of conditions. + samples (list[ConditioningAttributes]): List of conditions. Returns: - tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None. + list[ConditioningAttributes]: List of conditions after all attributes were set to None. """ if not self.training: return samples @@ -715,12 +1111,10 @@ class ClassifierFreeGuidanceDropout(DropoutModule): # nullify conditions of all attributes samples = deepcopy(samples) - for condition_type in ["wav", "text"]: for sample in samples: for condition in sample.attributes[condition_type]: dropout_condition(sample, condition_type, condition) - return samples def __repr__(self): @@ -728,29 +1122,25 @@ class ClassifierFreeGuidanceDropout(DropoutModule): class ConditioningProvider(nn.Module): - """Main class to provide conditions given all the supported conditioners. + """Prepare and provide conditions given all the supported conditioners. Args: conditioners (dict): Dictionary of conditioners. - merge_text_conditions_p (float, optional): Probability to merge all text sources - into a single text condition. Defaults to 0. - drop_desc_p (float, optional): Probability to drop the original description - when merging all text sources into a single text condition. Defaults to 0. - device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types. + device (torch.device or str, optional): Device for conditioners and output condition types. """ - def __init__( - self, - conditioners: tp.Dict[str, BaseConditioner], - merge_text_conditions_p: float = 0, - drop_desc_p: float = 0, - device: tp.Union[torch.device, str] = "cpu", - ): + def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"): super().__init__() self.device = device - self.merge_text_conditions_p = merge_text_conditions_p - self.drop_desc_p = drop_desc_p self.conditioners = nn.ModuleDict(conditioners) + @property + def joint_embed_conditions(self): + return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)] + + @property + def has_joint_embed_conditions(self): + return len(self.joint_embed_conditions) > 0 + @property def text_conditions(self): return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)] @@ -769,33 +1159,36 @@ class ConditioningProvider(nn.Module): This will return a dict matching conditioner names to their arbitrary tokenized representations. 
Args: - inputs (list[ConditioningAttribres]): List of ConditioningAttributes objects containing + inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing text and wav conditions. """ - assert all([type(x) == ConditioningAttributes for x in inputs]), \ - "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \ + assert all([isinstance(x, ConditioningAttributes) for x in inputs]), ( + "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]", f" but types were {set([type(x) for x in inputs])}" + ) output = {} text = self._collate_text(inputs) wavs = self._collate_wavs(inputs) + joint_embeds = self._collate_joint_embeds(inputs) - assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \ - f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}" + assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), ( + f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ", + f"got {text.keys(), wavs.keys(), joint_embeds.keys()}" + ) - for attribute, batch in chain(text.items(), wavs.items()): + for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()): output[attribute] = self.conditioners[attribute].tokenize(batch) return output def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]: - """Compute pairs of `(embedding, mask)` using the configured conditioners - and the tokenized representations. The output is for example: - - { - "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), - "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), - ... - } + """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations. + The output is for example: + { + "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), + "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), + ... + } Args: tokenized (dict): Dict of tokenized representations as returned by `tokenize()`. @@ -820,51 +1213,22 @@ class ConditioningProvider(nn.Module): "genre": ["Rock", "Hip-hop"], "description": ["A rock song with a guitar solo", "A hip-hop verse"] } - """ - batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list) - - def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0): - def is_valid(k, v): - k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument'] - v_valid = v is not None and isinstance(v, (int, float, str, list)) - return k_valid and v_valid - - def process_value(v): - if isinstance(v, (int, float, str)): - return v - if isinstance(v, list): - return ", ".join(v) - else: - RuntimeError(f"unknown type for text value! ({type(v), v})") - - desc = cond.text['description'] - meta_data = "" - if random.uniform(0, 1) < merge_text_conditions_p: - meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)] - random.shuffle(meta_pairs) - meta_data = ". ".join(meta_pairs) - desc = desc if not random.uniform(0, 1) < drop_desc_p else None - - if desc is None: - desc = meta_data if len(meta_data) > 1 else None - else: - desc = desc.rstrip('.') + ". 
" + meta_data - cond.text['description'] = desc.strip() if desc else None - - if self.training and self.merge_text_conditions_p: - for sample in samples: - _merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p) + Args: + samples (list of ConditioningAttributes): List of ConditioningAttributes samples. + Returns: + dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch. + """ + out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list) texts = [x.text for x in samples] for text in texts: for condition in self.text_conditions: - batch_per_attribute[condition].append(text[condition]) - - return batch_per_attribute + out[condition].append(text[condition]) + return out - def _collate_wavs(self, samples: tp.List[ConditioningAttributes]): + def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]: """Generate a dict where the keys are attributes by which we fetch similar wavs, - and the values are Tensors of wavs according to said attribtues. + and the values are Tensors of wavs according to said attributes. *Note*: by the time the samples reach this function, each sample should have some waveform inside the "wav" attribute. It should be either: @@ -873,27 +1237,89 @@ class ConditioningProvider(nn.Module): 3. A null waveform due to it being dropped in a dropout module (nullified by dropout) Args: - samples (tp.List[ConditioningAttributes]): List of ConditioningAttributes samples. + samples (list of ConditioningAttributes): List of ConditioningAttributes samples. Returns: - dict: A dicionary mapping an attribute name to wavs. + dict[str, WavCondition]: A dictionary mapping an attribute name to wavs. """ wavs = defaultdict(list) - lens = defaultdict(list) + lengths = defaultdict(list) + sample_rates = defaultdict(list) paths = defaultdict(list) - out = {} + seek_times = defaultdict(list) + out: tp.Dict[str, WavCondition] = {} for sample in samples: for attribute in self.wav_conditions: - wav, length, path = sample.wav[attribute] - wavs[attribute].append(wav.flatten()) - lens[attribute].append(length) - paths[attribute].append(path) + wav, length, sample_rate, path, seek_time = sample.wav[attribute] + assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]" + assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1" + # mono-channel conditioning + wav = wav.mean(1, keepdim=True) # [1, 1, T] + wavs[attribute].append(wav.flatten()) # [T] + lengths[attribute].append(length) + sample_rates[attribute].extend(sample_rate) + paths[attribute].extend(path) + seek_times[attribute].extend(seek_time) # stack all wavs to a single tensor for attribute in self.wav_conditions: stacked_wav, _ = collate(wavs[attribute], dim=0) - out[attribute] = WavCondition(stacked_wav.unsqueeze(1), - torch.cat(lens['self_wav']), paths[attribute]) # type: ignore + out[attribute] = WavCondition( + stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute], + paths[attribute], seek_times[attribute]) + + return out + + def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]: + """Generate a dict where the keys are attributes by which we compute joint embeddings, + and the values are Tensors of pre-computed embeddings and the corresponding text attributes. + + Args: + samples (list[ConditioningAttributes]): List of ConditioningAttributes samples. + Returns: + A dictionary mapping an attribute name to joint embeddings. 
+ """ + texts = defaultdict(list) + wavs = defaultdict(list) + lengths = defaultdict(list) + sample_rates = defaultdict(list) + paths = defaultdict(list) + seek_times = defaultdict(list) + channels: int = 0 + + out = {} + for sample in samples: + for attribute in self.joint_embed_conditions: + wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute] + assert wav.dim() == 3 + if channels == 0: + channels = wav.size(1) + else: + assert channels == wav.size(1), "not all audio has same number of channels in batch" + assert wav.size(0) == 1, "Expecting single-wav batch in the collate method" + wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T] + wavs[attribute].append(wav) + texts[attribute].extend(text) + lengths[attribute].append(length) + sample_rates[attribute].extend(sample_rate) + paths[attribute].extend(path) + seek_times[attribute].extend(seek_time) + + for attribute in self.joint_embed_conditions: + stacked_texts = texts[attribute] + stacked_paths = paths[attribute] + stacked_seek_times = seek_times[attribute] + stacked_wavs = pad_sequence(wavs[attribute]).to(self.device) + stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels) + stacked_sample_rates = sample_rates[attribute] + stacked_lengths = torch.cat(lengths[attribute]).to(self.device) + assert stacked_lengths.size(0) == stacked_wavs.size(0) + assert len(stacked_sample_rates) == stacked_wavs.size(0) + assert len(stacked_texts) == stacked_wavs.size(0) + out[attribute] = JointEmbedCondition( + text=stacked_texts, wav=stacked_wavs, + length=stacked_lengths, sample_rate=stacked_sample_rates, + path=stacked_paths, seek_time=stacked_seek_times) return out @@ -920,7 +1346,7 @@ class ConditionFuser(StreamingModule): super().__init__() assert all( [k in self.FUSING_METHODS for k in fuse2cond.keys()] - ), f"got invalid fuse method, allowed methods: {self.FUSING_MEHTODS}" + ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}" self.cross_attention_pos_emb = cross_attention_pos_emb self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond @@ -931,16 +1357,16 @@ class ConditionFuser(StreamingModule): def forward( self, - input: Tensor, + input: torch.Tensor, conditions: tp.Dict[str, ConditionType] - ) -> tp.Tuple[Tensor, tp.Optional[Tensor]]: + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """Fuse the conditions to the provided model input. Args: - input (Tensor): Transformer input. - conditions (tp.Dict[str, ConditionType]): Dict of conditions. + input (torch.Tensor): Transformer input. + conditions (dict[str, ConditionType]): Dict of conditions. Returns: - tp.Tuple[Tensor, Tensor]: The first tensor is the transformer input + tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input after the conditions have been fused. The second output tensor is the tensor used for cross-attention or None if no cross attention inputs exist. 
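`ConditionFuser` routes each conditioning stream to one of four ops: `sum` adds it to the transformer input, `input_interpolate` resamples it along time before adding, `prepend` concatenates it in front of the sequence, and `cross` collects it into the tensor returned for cross-attention. A shape-level sketch of those ops on dummy tensors (the attribute names are made up for illustration):

```python
import torch
import torch.nn.functional as F

B, T, D = 2, 10, 8
x = torch.zeros(B, T, D)                                      # transformer input
conds = {
    "genre":       (torch.randn(B, 1, D), "sum"),             # broadcast-added over time
    "chroma":      (torch.randn(B, 25, D), "input_interpolate"),
    "style_token": (torch.randn(B, 3, D), "prepend"),
    "description": (torch.randn(B, 6, D), "cross"),
}

cross = None
for name, (cond, op) in conds.items():
    if op == "sum":
        x = x + cond
    elif op == "input_interpolate":
        cond = F.interpolate(cond.transpose(1, 2), size=x.shape[1])   # resample along time
        x = x + cond.transpose(1, 2)
    elif op == "prepend":
        x = torch.cat([cond, x], dim=1)
    elif op == "cross":
        cross = cond if cross is None else torch.cat([cross, cond], dim=1)

print(x.shape, cross.shape)   # torch.Size([2, 13, 8]) torch.Size([2, 6, 8])
```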
""" @@ -959,16 +1385,16 @@ class ConditionFuser(StreamingModule): cross_attention_output = None for cond_type, (cond, cond_mask) in conditions.items(): op = self.cond2fuse[cond_type] - if op == "sum": + if op == 'sum': input += cond - elif op == "input_interpolate": - cond = rearrange(cond, "b t d -> b d t") + elif op == 'input_interpolate': + cond = einops.rearrange(cond, "b t d -> b d t") cond = F.interpolate(cond, size=input.shape[1]) - input += rearrange(cond, "b d t -> b t d") - elif op == "prepend": + input += einops.rearrange(cond, "b d t -> b t d") + elif op == 'prepend': if first_step: input = torch.cat([cond, input], dim=1) - elif op == "cross": + elif op == 'cross': if cross_attention_output is not None: cross_attention_output = torch.cat([cross_attention_output, cond], dim=1) else: diff --git a/audiocraft/modules/conv.py b/audiocraft/modules/conv.py index 972938ab84712eb06e1b10cea25444eee51d6637..c5d140c87386687b27f18d3cf3b04bc15b5cdba1 100644 --- a/audiocraft/modules/conv.py +++ b/audiocraft/modules/conv.py @@ -11,7 +11,7 @@ import warnings import torch from torch import nn from torch.nn import functional as F -from torch.nn.utils import spectral_norm, weight_norm +from torch.nn.utils.parametrizations import spectral_norm, weight_norm CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', diff --git a/audiocraft/modules/diffusion_schedule.py b/audiocraft/modules/diffusion_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..74ca6e3f2e7c4ff904d96dade315b0b46856778d --- /dev/null +++ b/audiocraft/modules/diffusion_schedule.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Functions for Noise Schedule, defines diffusion process, reverse process and data processor. +""" + +from collections import namedtuple +import random +import typing as tp +import julius +import torch + +TrainingItem = namedtuple("TrainingItem", "noisy noise step") + + +def betas_from_alpha_bar(alpha_bar): + alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]]) + return 1 - alphas + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + return x + + def return_sample(self, z: torch.Tensor): + """Project back from diffusion space to the actual sample space.""" + return z + + +class MultiBandProcessor(SampleProcessor): + """ + MultiBand sample processor. The input audio is splitted across + frequency bands evenly distributed in mel-scale. + + Each band will be rescaled to match the power distribution + of Gaussian noise in that band, using online metrics + computed on the first few samples. + + Args: + n_bands (int): Number of mel-bands to split the signal over. + sample_rate (int): Sample rate of the audio. + num_samples (int): Number of samples to use to fit the rescaling + for each band. The processor won't be stable + until it has seen that many samples. + power_std (float or list/tensor): The rescaling factor computed to match the + power of Gaussian noise in each band is taken to + that power, i.e. `1.` means full correction of the energy + in each band, and values less than `1` means only partial + correction. Can be used to balance the relative importance + of low vs. high freq in typical audio signals. 
+ """ + def __init__(self, n_bands: int = 8, sample_rate: float = 24_000, + num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.): + super().__init__() + self.n_bands = n_bands + self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands) + self.num_samples = num_samples + self.power_std = power_std + if isinstance(power_std, list): + assert len(power_std) == n_bands + power_std = torch.tensor(power_std) + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(n_bands)) + self.register_buffer('sum_x2', torch.zeros(n_bands)) + self.register_buffer('sum_target_x2', torch.zeros(n_bands)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + self.sum_target_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + return std + + @property + def target_std(self): + target_std = self.sum_target_x2 / self.counts + return target_std + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + bands = self.split_bands(x) + if self.counts.item() < self.num_samples: + ref_bands = self.split_bands(torch.randn_like(x)) + self.counts += len(x) + self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1) + self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1) + self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1) + return bands.sum(dim=0) + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + bands = self.split_bands(x) + rescale = (self.std / self.target_std) ** self.power_std + bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1) + return bands.sum(dim=0) + + +class NoiseSchedule: + """Noise schedule for diffusion. + + Args: + beta_t0 (float): Variance of the first diffusion step. + beta_t1 (float): Variance of the last diffusion step. + beta_exp (float): Power schedule exponent + num_steps (int): Number of diffusion step. + variance (str): choice of the sigma value for the denoising eq. 
Choices: "beta" or "beta_tilde" + clip (float): clipping value for the denoising steps + rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1) + repartition (str): shape of the schedule only power schedule is supported + sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution + noise_scale (float): Scaling factor for the noise + """ + def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta', + clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1, + repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None, + sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs): + + self.beta_t0 = beta_t0 + self.beta_t1 = beta_t1 + self.variance = variance + self.num_steps = num_steps + self.clip = clip + self.sample_processor = sample_processor + self.rescale = rescale + self.n_bands = n_bands + self.noise_scale = noise_scale + assert n_bands is None + if repartition == "power": + self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps, + device=device, dtype=torch.float) ** beta_exp + else: + raise RuntimeError('Not implemented') + self.rng = random.Random(1234) + + def get_beta(self, step: tp.Union[int, torch.Tensor]): + if self.n_bands is None: + return self.betas[step] + else: + return self.betas[:, step] # [n_bands, len(step)] + + def get_initial_noise(self, x: torch.Tensor): + if self.n_bands is None: + return torch.randn_like(x) + return torch.randn((x.size(0), self.n_bands, x.size(2))) + + def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor: + """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step.""" + if step is None: + return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands + if type(step) is int: + return (1 - self.betas[:step + 1]).prod() + else: + return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1) + + def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem: + """Create a noisy data item for diffusion model training: + + Args: + x (torch.Tensor): clean audio data torch.tensor(bs, 1, T) + tensor_step (bool): If tensor_step = false, only one step t is sample, + the whole batch is diffused to the same step and t is int. + If tensor_step = true, t is a tensor of size (x.size(0),) + every element of the batch is diffused to a independently sampled. + """ + step: tp.Union[int, torch.Tensor] + if tensor_step: + bs = x.size(0) + step = torch.randint(0, self.num_steps, size=(bs,), device=x.device) + else: + step = self.rng.randrange(self.num_steps) + alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1] + + x = self.sample_processor.project_sample(x) + noise = torch.randn_like(x) + noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale + return TrainingItem(noisy, noise, step) + + def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None, + condition: tp.Optional[torch.Tensor] = None, return_list: bool = False): + """Full ddpm reverse process. + + Args: + model (nn.Module): Diffusion model. + initial (tensor): Initial Noise. + condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation). + return_list (bool): Whether to return the whole process or only the sampled point. 
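`get_training_item` implements the standard DDPM forward process: noisy = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps, with the optional `rescale` and `noise_scale` factors layered on top. A minimal sketch of that noising step:

```python
import torch


def diffuse(x: torch.Tensor, betas: torch.Tensor, step: int,
            rescale: float = 1.0, noise_scale: float = 1.0):
    """Return (noisy, noise) for a clean sample x at diffusion step `step`."""
    alpha_bar = (1 - betas[:step + 1]).prod()              # cumulative product up to `step`
    noise = torch.randn_like(x)
    noisy = (alpha_bar.sqrt() / rescale) * x + (1 - alpha_bar).sqrt() * noise * noise_scale
    return noisy, noise


betas = torch.linspace(1e-4, 0.02, 1000)
x = torch.randn(4, 1, 16000)                               # e.g. a batch of mono waveforms
noisy, noise = diffuse(x, betas, step=500)
print(noisy.shape)                                          # torch.Size([4, 1, 16000])
```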
+ """ + alpha_bar = self.get_alpha_bar(step=self.num_steps - 1) + current = initial + iterates = [initial] + for step in range(self.num_steps)[::-1]: + with torch.no_grad(): + estimate = model(current, step, condition=condition).sample + alpha = 1 - self.betas[step] + previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt() + previous_alpha_bar = self.get_alpha_bar(step=step - 1) + if step == 0: + sigma2 = 0 + elif self.variance == 'beta': + sigma2 = 1 - alpha + elif self.variance == 'beta_tilde': + sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha) + elif self.variance == 'none': + sigma2 = 0 + else: + raise ValueError(f'Invalid variance type {self.variance}') + + if sigma2 > 0: + previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale + if self.clip: + previous = previous.clamp(-self.clip, self.clip) + current = previous + alpha_bar = previous_alpha_bar + if step == 0: + previous *= self.rescale + if return_list: + iterates.append(previous.cpu()) + + if return_list: + return iterates + else: + return self.sample_processor.return_sample(previous) + + def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None, + condition: tp.Optional[torch.Tensor] = None, return_list: bool = False): + """Reverse process that only goes through Markov chain states in step_list.""" + if step_list is None: + step_list = list(range(1000))[::-50] + [0] + alpha_bar = self.get_alpha_bar(step=self.num_steps - 1) + alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu() + betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled) + current = initial * self.noise_scale + iterates = [current] + for idx, step in enumerate(step_list[:-1]): + with torch.no_grad(): + estimate = model(current, step, condition=condition).sample * self.noise_scale + alpha = 1 - betas_subsampled[-1 - idx] + previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt() + previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1]) + if step == step_list[-2]: + sigma2 = 0 + previous_alpha_bar = torch.tensor(1.0) + else: + sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha) + if sigma2 > 0: + previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale + if self.clip: + previous = previous.clamp(-self.clip, self.clip) + current = previous + alpha_bar = previous_alpha_bar + if step == 0: + previous *= self.rescale + if return_list: + iterates.append(previous.cpu()) + if return_list: + return iterates + else: + return self.sample_processor.return_sample(previous) diff --git a/audiocraft/modules/rope.py b/audiocraft/modules/rope.py index 4b8c70b9aba28eeb53d12ddc3de8852492847808..c12cee0954f27c45d79627771fdf7fa9fc10dfcc 100644 --- a/audiocraft/modules/rope.py +++ b/audiocraft/modules/rope.py @@ -18,7 +18,7 @@ class XPos(nn.Module): dim (int): Embedding dimension. smoothing (float): Smoothing factor applied to the decay rates. base_scale (int): Base decay rate, given in terms of scaling time. - device (torch.device or None): Device on which to initialize the module. + device (torch.device, optional): Device on which to initialize the module. dtype (torch.dtype): dtype to use to generate the embedding. 
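Each iteration of `generate` above is the usual DDPM posterior update: remove the scaled noise estimate, divide by sqrt(alpha_t), then (except at step 0) add fresh noise with variance beta_t or beta_tilde_t, clipping along the way. A sketch of a single reverse step, omitting the `rescale`/`noise_scale` knobs and assuming `estimate` is the model's noise prediction:

```python
import torch


def reverse_step(current: torch.Tensor, estimate: torch.Tensor,
                 betas: torch.Tensor, step: int, clip: float = 5.0,
                 variance: str = "beta") -> torch.Tensor:
    """One DDPM denoising step from x_t (`current`) to x_{t-1}."""
    alphas_bar = (1 - betas).cumprod(dim=0)
    alpha, alpha_bar = 1 - betas[step], alphas_bar[step]
    previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
    if step > 0:
        if variance == "beta":
            sigma2 = 1 - alpha
        else:  # "beta_tilde"
            sigma2 = (1 - alphas_bar[step - 1]) / (1 - alpha_bar) * (1 - alpha)
        previous = previous + sigma2 ** 0.5 * torch.randn_like(previous)
    return previous.clamp(-clip, clip)


betas = torch.linspace(1e-4, 0.02, 1000)
x_t = torch.randn(1, 1, 16000)
x_prev = reverse_step(x_t, torch.randn_like(x_t), betas, step=999)
print(x_prev.shape)   # torch.Size([1, 1, 16000])
```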
""" def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512, @@ -36,8 +36,7 @@ class XPos(nn.Module): self.decay: tp.Optional[torch.Tensor] = None def get_decay(self, start: int, end: int): - """Create complex decay tensor, cache values for fast computation. - """ + """Create complex decay tensor, cache values for fast computation.""" if self.decay is None or end > self.decay.shape[0]: assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker. idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype) @@ -55,7 +54,7 @@ class RotaryEmbedding(nn.Module): max_period (float): Maximum period of the rotation frequencies. xpos (bool): Use xPos, applies an exponential decay to rotation matrix. scale (float): Scale of positional embedding, set to 0 to deactivate. - device (torch.device or None): Device on which to initialize the module. + device (torch.device, optional): Device on which to initialize the module. dtype (torch.dtype): dtype to use to generate the embedding. """ def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False, @@ -74,8 +73,7 @@ class RotaryEmbedding(nn.Module): self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None def get_rotation(self, start: int, end: int): - """Create complex rotation tensor, cache values for fast computation. - """ + """Create complex rotation tensor, cache values for fast computation.""" if self.rotation is None or end > self.rotation.shape[0]: assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker. idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype) @@ -83,14 +81,16 @@ class RotaryEmbedding(nn.Module): self.rotation = torch.polar(torch.ones_like(angles), angles) return self.rotation[start:end] - def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False): - """Apply rope rotation to query or key tensor. - """ - T = x.shape[1] - rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2) + def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False): + """Apply rope rotation to query or key tensor.""" + T = x.shape[time_dim] + target_shape = [1] * x.dim() + target_shape[time_dim] = T + target_shape[-1] = -1 + rotation = self.get_rotation(start, start + T).view(target_shape) if self.xpos: - decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2) + decay = self.xpos.get_decay(start, start + T).view(target_shape) else: decay = 1.0 @@ -99,26 +99,27 @@ class RotaryEmbedding(nn.Module): x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2)) scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale) - x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2) + x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x) return x_out.type_as(x) - def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0): + def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1): """ Apply rope rotation to both query and key tensors. Supports streaming mode, in which query and key are not expected to have the same shape. - In streaming mode, key will be of legnth [P + C] with P the cached past timesteps, but + In streaming mode, key will be of length [P + C] with P the cached past timesteps, but query will be [C] (typically C == 1). Args: query (torch.Tensor): Query to rotate. key (torch.Tensor): Key to rotate. start (int): Start index of the sequence for time offset. 
+ time_dim (int): which dimension represent the time steps. """ - query_timesteps = query.shape[1] - key_timesteps = key.shape[1] + query_timesteps = query.shape[time_dim] + key_timesteps = key.shape[time_dim] streaming_offset = key_timesteps - query_timesteps - query_out = self.rotate(query, start + streaming_offset) - key_out = self.rotate(key, start, invert_decay=True) + query_out = self.rotate(query, start + streaming_offset, time_dim) + key_out = self.rotate(key, start, time_dim, invert_decay=True) return query_out, key_out diff --git a/audiocraft/modules/transformer.py b/audiocraft/modules/transformer.py index e69cca829d774d0b8b36c0de9b7924373da81b43..e8100a4cff720739c4f870632cd1ccda26a2620c 100644 --- a/audiocraft/modules/transformer.py +++ b/audiocraft/modules/transformer.py @@ -35,8 +35,8 @@ def set_efficient_attention_backend(backend: str = 'torch'): _efficient_attention_backend = backend -def _get_attention_time_dimension() -> int: - if _efficient_attention_backend == 'torch': +def _get_attention_time_dimension(memory_efficient: bool) -> int: + if _efficient_attention_backend == 'torch' and memory_efficient: return 2 else: return 1 @@ -89,11 +89,11 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) -def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers""" +def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers.""" if n_rep == 1: return x - if _efficient_attention_backend == 'torch': + if _efficient_attention_backend == 'torch' and memory_efficient: bs, n_kv_heads, slen, head_dim = x.shape return ( x[:, :, None, :, :] @@ -111,14 +111,14 @@ def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: class LayerScale(nn.Module): """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). - This rescales diagonaly the residual outputs close to 0, with a learnt scale. + This rescales diagonally the residual outputs close to 0, with a learnt scale. Args: channels (int): Number of channels. init (float): Initial scale. channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`. - device (torch.device or None): Device on which to initialize the module. - dtype (torch.dtype or None): dtype to use to initialize the module. + device (torch.device or str, optional): Device on which to initialize the module. + dtype (torch.dtype, optional): dtype to use to initialize the module. """ def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True, device=None, dtype=None): @@ -144,22 +144,22 @@ class StreamingMultiheadAttention(StreamingModule): dropout (float): Dropout level. bias (bool): Use bias in projections. causal (bool): Causal mask applied automatically. - past_context (int or None): Receptive field for the causal mask, infinite if None. + past_context (int, optional): Receptive field for the causal mask, infinite if None. custom (bool): Use custom MHA implementation, for testing / benchmarking. memory_efficient (bool): Use xformers based memory efficient attention. attention_as_float32 (bool): Perform the attention as float32 (especially important with memory_efficient as autocast won't do this automatically). - rope (`RotaryEmbedding` or None): Rope embedding to use. + rope (`RotaryEmbedding`, optional): Rope embedding to use. 
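The `time_dim` argument added to `rotate`/`rotate_qk` lets RoPE handle either `[B, T, H, D]` or `[B, H, T, D]` layouts by viewing the complex rotation tensor into a broadcast-friendly shape. A self-contained sketch of that layout-agnostic rotation, using standard RoPE frequencies and the same interleaved pairing as the module (not the module's exact API):

```python
import torch


def rope_rotate(x: torch.Tensor, max_period: float = 10000.0, time_dim: int = 1) -> torch.Tensor:
    """Apply rotary position embedding to x along `time_dim`; the last dim is the head dimension."""
    D, T = x.shape[-1], x.shape[time_dim]
    freqs = 1.0 / (max_period ** (torch.arange(0, D, 2, dtype=torch.float32) / D))   # [D/2]
    angles = torch.arange(T, dtype=torch.float32).view(-1, 1) * freqs                # [T, D/2]
    rotation = torch.polar(torch.ones_like(angles), angles)                          # complex [T, D/2]
    # view the rotation so it broadcasts whatever position the time axis occupies
    target_shape = [1] * x.dim()
    target_shape[time_dim] = T
    target_shape[-1] = -1
    rotation = rotation.view(target_shape)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_complex * rotation).view_as(x).type_as(x)


q_btHD = torch.randn(2, 16, 8, 64)          # [B, T, H, D] layout
q_bHtD = q_btHD.transpose(1, 2)             # [B, H, T, D] layout
out1 = rope_rotate(q_btHD, time_dim=1)
out2 = rope_rotate(q_bHtD, time_dim=2)
print(torch.allclose(out1.transpose(1, 2), out2, atol=1e-5))   # True
```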
cross_attention: Should be true when used as a cross attention. All keys and values must be available at once, streaming is only for the queries. Cannot be used with `causal` or `rope` (as it wouldn't make sens to - intepret the time steps in the keys relative to those in the queries). + interpret the time steps in the keys relative to those in the queries). safe_streaming (bool): Bug fix, will go away with xformers update. qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product. kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). This will lead to faster decoding time on A100 or other GPUs with tensorcore. - device (torch.device or None): Sevice on which to initialize. - dtype (torch.dtype or None): dtype to use. + device (torch.device, optional): Device on which to initialize. + dtype (torch.dtype, optional): dtype to use. """ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False, @@ -234,14 +234,14 @@ class StreamingMultiheadAttention(StreamingModule): # Return a causal mask, accounting for potentially stored past keys/values # We actually return a bias for the attention score, as this has the same # convention both in the builtin MHA in Pytorch, and Xformers functions. - time_dim = _get_attention_time_dimension() + time_dim = _get_attention_time_dimension(self.memory_efficient) if self.memory_efficient: from xformers.ops import LowerTriangularMask if current_steps == 1: # If we only have one step, then we do not need a mask. return None elif 'past_keys' in self._streaming_state: - raise RuntimeError('Not supported at the moment') + raise RuntimeError("Not supported at the moment") else: # Then we can safely use a lower triangular mask return LowerTriangularMask() @@ -264,7 +264,7 @@ class StreamingMultiheadAttention(StreamingModule): torch.full([], float('-inf'), device=device, dtype=dtype)) def _complete_kv(self, k, v): - time_dim = _get_attention_time_dimension() + time_dim = _get_attention_time_dimension(self.memory_efficient) if self.cross_attention: # With cross attention we assume all keys and values # are already available, and streaming is with respect @@ -298,8 +298,7 @@ class StreamingMultiheadAttention(StreamingModule): return nk, nv def _apply_rope(self, query: torch.Tensor, key: torch.Tensor): - # TODO: fix and verify layout. - assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.' + time_dim = _get_attention_time_dimension(self.memory_efficient) # Apply rope embeddings to query and key tensors. 
assert self.rope is not None if 'past_keys' in self._streaming_state: @@ -311,16 +310,16 @@ class StreamingMultiheadAttention(StreamingModule): else: past_context_offset = 0 streaming_offset = past_context_offset + past_keys_offset - return self.rope.rotate_qk(query, key, start=streaming_offset) + return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim) def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_padding_mask=None, need_weights=False, attn_mask=None, average_attn_weights=True, is_causal=False): assert attn_mask is None - assert not is_causal, ("new param added in torch 2.0.1 not supported, " + assert not is_causal, ("New param added in torch 2.0.1 not supported, " "use the causal args in the constructor.") - time_dim = _get_attention_time_dimension() + time_dim = _get_attention_time_dimension(self.memory_efficient) if time_dim == 2: layout = "b h t d" else: @@ -394,8 +393,8 @@ class StreamingMultiheadAttention(StreamingModule): q, k = self._apply_rope(q, k) k, v = self._complete_kv(k, v) if self.kv_repeat > 1: - k = expand_repeated_kv(k, self.kv_repeat) - v = expand_repeated_kv(v, self.kv_repeat) + k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient) + v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient) if self.attention_as_float32: q, k, v = [x.float() for x in [q, k, v]] if self.memory_efficient: @@ -455,7 +454,7 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer): bias_ff (bool): Use bias for FF. bias_attn (bool): Use bias for MHA. causal (bool): Causal mask applied automatically. - past_context (int or None): Receptive field for the causal mask, infinite if None. + past_context (int, optional): Receptive field for the causal mask, infinite if None. custom (bool): Use custom MHA implementation, for testing / benchmarking. memory_efficient (bool): Use xformers based memory efficient attention. attention_as_float32 (bool): Perform the attention as float32 @@ -465,15 +464,15 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer): cross_attention (bool): If True, expect to get secondary input for cross-attention. Cross attention will use the default MHA, as it typically won't require special treatment. - layer_scale (float or None): If not None, LayerScale will be used with + layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale. - rope (`RotaryEmbedding` or None): Rope embedding to use. - attention_dropout (float or None): If not None, separate the value of the dimension dropout + rope (`RotaryEmbedding`, optional): Rope embedding to use. + attention_dropout (float, optional): If not None, separate the value of the dimension dropout in FFN and of the attention dropout. kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). This will lead to faster decoding time on A100 or other GPUs with tensorcore. - device (torch.device or None): Device on which to initialize. - dtype (torch.dtype or None): dtype to use. + device (torch.device, optional): Device on which to initialize. + dtype (torch.dtype, optional): dtype to use. **kwargs: See `nn.TransformerEncoderLayer`. """ def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1, @@ -576,30 +575,30 @@ class StreamingTransformer(StreamingModule): bias_ff (bool): Use bias for FF. bias_attn (bool): Use bias for MHA. causal (bool): Causal mask applied automatically. 
- past_context (int or None): Receptive field for the causal mask, infinite if None. + past_context (int, optional): Receptive field for the causal mask, infinite if None. custom (bool): Use custom MHA implementation, for testing / benchmarking. memory_efficient (bool): Use xformers based memory efficient attention. attention_as_float32 (bool): Perform the attention as float32 (especially important with memory_efficient as autocast won't do this automatically). cross_attention (bool): If True, expect to get secondary input for cross-attention. - layer_scale (float or None): If not None, LayerScale will be used + layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale. positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope). max_period (float): Maximum period of the time embedding. positional_scale (float): Scale of positional embedding, set to 0 to deactivate. xpos (bool): Apply xpos exponential decay to positional embedding (rope only). - lr (float or None): learning rate override through the `make_optim_group` API. - weight_decay (float or None): Weight_decay override through the `make_optim_group` API. + lr (float, optional): learning rate override through the `make_optim_group` API. + weight_decay (float, optional): Weight_decay override through the `make_optim_group` API. layer_class: (subclass of `StreamingTransformerLayer): class to use - to initialize the layers, allowing further customization outside of Audiocraft. + to initialize the layers, allowing further customization outside of AudioCraft. checkpointing (str): Checkpointing strategy to reduce memory usage. No checkpointing if set to 'none'. Per layer checkpointing using PyTorch if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice, minimal memory usage, but maximal runtime). Finally, `xformers_default` provide a policy for opting-out some operations of the checkpointing like linear layers and attention, providing a middle ground between speed and memory. - device (torch.device or None): Device on which to initialize. - dtype (torch.dtype or None): dtype to use. + device (torch.device, optional): Device on which to initialize. + dtype (torch.dtype, optional): dtype to use. **kwargs: See `nn.TransformerEncoderLayer`. """ def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048, @@ -649,7 +648,6 @@ class StreamingTransformer(StreamingModule): # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the # backward hook inside of FSDP... 
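# Illustrative sketch (assumption, not part of the diff hunks above): consuming
# the per-transformer `lr` / `weight_decay` overrides surfaced through the
# `make_optim_group` API described in the docstring. `transformer` stands in for
# a configured StreamingTransformer instance.
import torch

def build_optimizer(transformer, base_lr: float = 1e-4):
    # The returned group holds the parameters plus optional 'lr' /
    # 'weight_decay' keys that take precedence over the optimizer defaults.
    group = transformer.make_optim_group()
    return torch.optim.AdamW([group], lr=base_lr)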
layer._magma_checkpointed = True # type: ignore - assert layer.layer_drop == 0., "Need further checking" # type: ignore def _apply_layer(self, layer, *args, **kwargs): method = self.checkpointing @@ -713,7 +711,7 @@ class StreamingTransformer(StreamingModule): return group -# special attention attention related function +# special attention related function def _verify_xformers_memory_efficient_compat(): try: diff --git a/audiocraft/quantization/core_vq.py b/audiocraft/quantization/core_vq.py index e1896bb1788a945a1f7be6369abb255ecf72c7a0..6aaa3b077c53b413e2b2a904ac7e769d1c623b36 100644 --- a/audiocraft/quantization/core_vq.py +++ b/audiocraft/quantization/core_vq.py @@ -75,7 +75,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10): return means, bins -def orthgonal_loss_fn(t): +def orthogonal_loss_fn(t): # eq (2) from https://arxiv.org/abs/2112.00384 n = t.shape[0] normed_codes = l2norm(t) @@ -237,7 +237,7 @@ class VectorQuantization(nn.Module): orthogonal_reg_weight (float): Orthogonal regularization weights. orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. orthogonal_reg_max_codes (optional int): Maximum number of codes to consider - for orthogonal regulariation. + for orthogonal regularization. threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. @@ -340,7 +340,7 @@ class VectorQuantization(nn.Module): rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes] codebook = codebook[rand_ids] - orthogonal_reg_loss = orthgonal_loss_fn(codebook) + orthogonal_reg_loss = orthogonal_loss_fn(codebook) loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight quantize = self.project_out(quantize) @@ -371,11 +371,16 @@ class ResidualVectorQuantization(nn.Module): for i, layer in enumerate(self.layers[:n_q]): quantized, indices, loss = layer(residual) + quantized = quantized.detach() residual = residual - quantized quantized_out = quantized_out + quantized all_indices.append(indices) all_losses.append(loss) + if self.training: + # Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25 + quantized_out = x + (quantized_out - x).detach() + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) return quantized_out, out_indices, out_losses diff --git a/audiocraft/utils/cache.py b/audiocraft/utils/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba017a761a29c44d3385e0b483877cb4a8d1ec1 --- /dev/null +++ b/audiocraft/utils/cache.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from concurrent.futures import ThreadPoolExecutor +from collections import deque +from functools import partial +from hashlib import sha1 +import logging +from pathlib import Path +import sys +import typing as tp +import zipfile + +import flashy +import torch + + +logger = logging.getLogger(__name__) + + +def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor: + """Utility function for the EmbeddingCache, returning the full embedding without any chunking. 
+ This method can be used in case there is no need in extracting a chunk of the full embedding + read from the cache. + + Args: + full_embed (torch.Tensor): The full embedding. + x (any): Batch object from which the full embedding is derived. + idx (torch.Tensor): Index of object to consider in the batch object. + Returns: + full_embed (torch.Tensor): The full embedding + """ + return full_embed.to(device) + + +class EmbeddingCache: + """Cache around embeddings computation for faster execution. + The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API + to retrieve the pre-computed embeddings on full inputs and extract only a given chunk + using a user-provided function. When the cache is warm (all embeddings are pre-computed), + the EmbeddingCache allows for faster training as it removes the need of computing the embeddings. + Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint + and synchronization points in the forward calls. + + Args: + cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk. + device (str or torch.device): Device on which the embedding is returned. + compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute + the embedding from a given object and path. This user provided function can compute the + embedding from the provided object or using the provided path as entry point. The last parameter + specify the index corresponding to the current embedding in the object that can represent batch metadata. + extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract + the desired embedding chunk from the full embedding loaded from the cache. The last parameter + specify the index corresponding to the current embedding in the object that can represent batch metadata. + If not specified, will return the full embedding unmodified. + """ + def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device], + compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor], + extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None): + self.cache_path = Path(cache_path) + self.device = device + self._compute_embed_fn = compute_embed_fn + self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor] + if extract_embed_fn is not None: + self._extract_embed_fn = extract_embed_fn + else: + self._extract_embed_fn = partial(get_full_embed, device=device) + if self.cache_path is not None: + self.cache_path.mkdir(exist_ok=True, parents=True) + logger.info(f"Cache instantiated at: {self.cache_path}") + self.pool = ThreadPoolExecutor(8) + self.pool.__enter__() + self._current_batch_cache: dict = {} + self._memory_cache: dict = {} + + def _get_cache_path(self, path: tp.Union[Path, str]): + """Get cache path for the given file path.""" + sig = sha1(str(path).encode()).hexdigest() + return self.cache_path / sig + + @staticmethod + def _get_full_embed_from_cache(cache: Path): + """Loads full pre-computed embedding from the cache.""" + try: + embed = torch.load(cache, 'cpu') + except Exception as exc: + logger.error("Error loading %s: %r", cache, exc) + embed = None + return embed + + def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor: + """Get embedding from cache, computing and storing it to cache if not already cached. 
+ The EmbeddingCache first tries to load the embedding from the in-memory cache + containing the pre-computed chunks populated through `populate_embed_cache`. + If not found, the full embedding is computed and stored on disk to be later accessed + to populate the in-memory cache, and the desired embedding chunk is extracted and returned. + + Args: + paths (list[Path or str]): List of paths from where the embeddings can be loaded. + x (any): Object from which the embedding is extracted. + """ + embeds = [] + for idx, path in enumerate(paths): + cache = self._get_cache_path(path) + if cache in self._current_batch_cache: + embed = self._current_batch_cache[cache] + else: + full_embed = self._compute_embed_fn(path, x, idx) + try: + with flashy.utils.write_and_rename(cache, pid=True) as f: + torch.save(full_embed.cpu(), f) + except Exception as exc: + logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc) + else: + logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape) + embed = self._extract_embed_fn(full_embed, x, idx) + embeds.append(embed) + embed = torch.stack(embeds, dim=0) + return embed + + def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None: + """Populate in-memory caches for embeddings reading from the embeddings stored on disk. + The in-memory caches consist in a cache for the full embedding and another cache for the + final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings + and reduce the IO footprint and synchronization points during forward passes. + + Args: + paths (list[Path]): List of paths from where the embeddings can be loaded. + x (any): Object from which the embedding is extracted. + """ + self._current_batch_cache.clear() + if self.cache_path is not None: + futures: list = [] + for path in paths: + assert path is not None, "Path is required for computation from cache" + cache = self._get_cache_path(path) + if cache in self._memory_cache or not cache.exists(): + futures.append(None) + else: + futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache)) + for idx, (path, future) in enumerate(zip(paths, futures)): + assert path is not None + cache = self._get_cache_path(path) + full_embed = None + if future is None: + if cache in self._memory_cache: + full_embed = self._memory_cache[cache] + else: + full_embed = future.result() + if full_embed is not None: + self._memory_cache[cache] = full_embed + full_embed = full_embed.to(self.device) + if full_embed is not None: + embed = self._extract_embed_fn(full_embed, x, idx) + self._current_batch_cache[cache] = embed + + +class CachedBatchWriter: + """Write pre computed caches for mini batches. This can + make loading a lot more efficient depending on your filesystem. + + Args: + cache_folder (Path): folder in which the cached minibatches + will be stored. + + Inside cache folder, the structure is the following: + `epoch_number / update_number.zip` + And the zip file contains one entry per batch item. + + It is possible to use the cache with a batch size smaller than + created with but obviously not larger. Make sure to call the + `start_epoch(epoch)` method for indicating changes of epochs. + + See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py` + for an example of how to warmup the cache. 
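# Illustrative sketch (assumption, not part of the diff hunks above): minimal use
# of the EmbeddingCache defined above. `compute_embed` stands in for a real
# conditioner (e.g. a T5 forward pass) and just returns a dummy tensor; flashy
# must be installed, since the cache saves through flashy.utils.write_and_rename.
from pathlib import Path
import torch
from audiocraft.utils.cache import EmbeddingCache

def compute_embed(path: Path, x, idx: int) -> torch.Tensor:
    return torch.zeros(1, 4, 8)   # placeholder embedding

cache = EmbeddingCache(cache_path="/tmp/embed_cache", device="cpu",
                       compute_embed_fn=compute_embed)
paths = [Path("a.wav"), Path("b.wav")]
cache.populate_embed_cache(paths, x=None)            # warm in-memory cache from disk
embeds = cache.get_embed_from_cache(paths, x=None)   # computes and caches on first call
assert embeds.shape == (2, 1, 4, 8)                  # one stacked entry per path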
+ """ + def __init__(self, cache_folder: Path): + self.cache_folder = cache_folder + self._current_epoch: tp.Optional[int] = None + self._current_index = 0 + + def start_epoch(self, epoch: int): + """Call at the beginning of each epoch. + """ + self._current_epoch = epoch + self._current_index = 0 + self._zip_path.parent.mkdir(exist_ok=True, parents=True) + + @staticmethod + def _get_zip_path(cache_folder: Path, epoch: int, index: int): + return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip" + + @property + def _zip_path(self): + assert self._current_epoch is not None + return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index) + + def save(self, *content): + """Save one mini batch. This function is distributed-aware + and will automatically merge all the items from the different + workers. + """ + all_contents = [] + for rank in range(flashy.distrib.world_size()): + their_content = flashy.distrib.broadcast_object(content, src=rank) + all_contents.append(their_content) + + if flashy.distrib.is_rank_zero(): + idx = 0 + with flashy.utils.write_and_rename(self._zip_path) as tmp: + with zipfile.ZipFile(tmp, 'w') as zf: + for content in all_contents: + for vals in zip(*content): + with zf.open(f'{idx}', 'w') as f: # type: ignore + torch.save(vals, f) + idx += 1 + flashy.distrib.barrier() + self._current_index += 1 + + +class CachedBatchLoader: + """Loader for cached mini-batches dumped with `CachedBatchWriter`. + + Args: + cache_folder (Path): folder in which the cached minibatches are stored. + batch_size (int): batch size (per GPU) expected. + num_workers (int): number of workers to use for loading. + min_length (int): minimum expected length for each epoch. If some + mini-batches are missing, and error is raised. + + This is iterable just like a regular DataLoader. + """ + + def __init__(self, cache_folder: Path, batch_size: int, + num_workers: int = 10, min_length: int = 1): + self.cache_folder = cache_folder + self.batch_size = batch_size + self.num_workers = num_workers + self.min_length = min_length + self._current_epoch: tp.Optional[int] = None + self.sampler = None # for compatibility with the regular DataLoader + + def __len__(self): + path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent + return len([p for p in path.iterdir() if p.suffix == ".zip"]) + + def start_epoch(self, epoch: int): + """Call at the beginning of each epoch. 
+ """ + self._current_epoch = epoch + + def _zip_path(self, index: int): + assert self._current_epoch is not None + return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index) + + def _load_one(self, index: int): + zip_path = self._zip_path(index) + if not zip_path.exists(): + if index < self.min_length: + raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist") + + return None + mode = "rb" if sys.version_info >= (3, 9) else "r" + try: + with zipfile.ZipFile(zip_path, 'r') as zf: + rank = flashy.distrib.rank() + world_size = flashy.distrib.world_size() + root = zipfile.Path(zf) + items = list(root.iterdir()) + total_batch_size = self.batch_size * world_size + if len(items) < total_batch_size: + raise RuntimeError( + f"The cache can handle a max batch size of {len(items)}, " + f"but {total_batch_size} is needed.") + start = rank * self.batch_size + items = items[start: start + self.batch_size] + assert len(items) == self.batch_size + entries = [] + entries = [torch.load(item.open(mode), 'cpu') for item in items] # type: ignore + transposed = zip(*entries) + out = [] + for part in transposed: + assert len(part) > 0 + if isinstance(part[0], torch.Tensor): + out.append(torch.stack(part)) + else: + assert isinstance(part, torch.Tensor) + out.append(part) + return out + except Exception: + logger.error("Error when reading zip path %s", zip_path) + raise + + def __iter__(self): + """This will yields tuples, exactly as provided to the + `CachedBatchWriter.save` method. + """ + pool = ThreadPoolExecutor(self.num_workers) + next_index = 0 + queue = deque() + + def _get_next(): + nonlocal next_index + r = queue.popleft().result() + if r is None: + return None + else: + queue.append(pool.submit(self._load_one, next_index)) + next_index += 1 + return r + + with pool: + # fill the buffer of fetching jobs. + for _ in range(2 * self.num_workers): + queue.append(pool.submit(self._load_one, next_index)) + next_index += 1 + while True: + batch = _get_next() + if batch is None: + return + yield batch diff --git a/audiocraft/utils/cluster.py b/audiocraft/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3380d031739d473fb859c76b9c25350f47fa77e8 --- /dev/null +++ b/audiocraft/utils/cluster.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility functions for SLURM configuration and cluster settings. +""" + +from enum import Enum +import os +import socket +import typing as tp + +import omegaconf + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + LOCAL_DARWIN = "darwin" + DEFAULT = "default" # used for any other cluster. 
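# Illustrative sketch (assumption, not part of the diff hunks above): round trip
# between the CachedBatchWriter and CachedBatchLoader defined above, run in a
# single process. Shapes and the cache folder are placeholders.
from pathlib import Path
import torch
from audiocraft.utils.cache import CachedBatchWriter, CachedBatchLoader

cache_dir = Path("/tmp/batch_cache")
writer = CachedBatchWriter(cache_dir)
writer.start_epoch(0)
for _ in range(4):                                   # pretend 4 training updates
    tokens = torch.zeros(8, 250, dtype=torch.long)   # batch of 8 items
    writer.save(tokens)                              # one zip entry per batch item

loader = CachedBatchLoader(cache_dir, batch_size=8)
loader.start_epoch(0)
for (tokens,) in loader:                             # yields tuples, as passed to save()
    assert tokens.shape == (8, 250)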
+ + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + fqdn = socket.getfqdn() + if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn): + return ClusterType.AWS + + if fqdn.endswith(".fair"): + return ClusterType.FAIR + + if fqdn.endswith(".facebook.com"): + return ClusterType.RSC + + if uname.sysname == "Darwin": + return ClusterType.LOCAL_DARWIN + + return ClusterType.DEFAULT + + +def get_cluster_type( + cluster_type: tp.Optional[ClusterType] = None, +) -> tp.Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_slurm_parameters( + cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None +) -> omegaconf.DictConfig: + """Update SLURM parameters in configuration based on cluster type. + If the cluster type is not specify, it infers it automatically. + """ + from ..environment import AudioCraftEnvironment + cluster_type = get_cluster_type(cluster_type) + # apply cluster-specific adjustments + if cluster_type == ClusterType.AWS: + cfg["mem_per_gpu"] = None + cfg["constraint"] = None + cfg["setup"] = [] + elif cluster_type == ClusterType.RSC: + cfg["mem_per_gpu"] = None + cfg["setup"] = [] + cfg["constraint"] = None + cfg["partition"] = "learn" + slurm_exclude = AudioCraftEnvironment.get_slurm_exclude() + if slurm_exclude is not None: + cfg["exclude"] = slurm_exclude + return cfg diff --git a/audiocraft/utils/export.py b/audiocraft/utils/export.py index b513b52267f7bf5aae09282c15b0a2e20c8a8fee..28b214017d9ac23934b67e8254a96131cefa6501 100644 --- a/audiocraft/utils/export.py +++ b/audiocraft/utils/export.py @@ -11,46 +11,69 @@ Utility to export a training checkpoint to a lightweight release checkpoint. from pathlib import Path import typing as tp -from omegaconf import OmegaConf, DictConfig +from omegaconf import OmegaConf import torch +from audiocraft import __version__ -def _clean_lm_cfg(cfg: DictConfig): - OmegaConf.set_struct(cfg, False) - # This used to be set automatically in the LM solver, need a more robust solution - # for the future. - cfg['transformer_lm']['card'] = 2048 - cfg['transformer_lm']['n_q'] = 4 - # Experimental params no longer supported. - bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', - 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] - for name in bad_params: - del cfg['transformer_lm'][name] - OmegaConf.set_struct(cfg, True) - return cfg - - -def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): - sig = Path(checkpoint_path).parent.name - assert len(sig) == 8, "Not a valid Dora signature" + +def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): + """Export only the best state from the given EnCodec checkpoint. This + should be used if you trained your own EnCodec model. 
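# Illustrative sketch (assumption, not part of the diff hunks above): packaging a
# trained MusicGen checkpoint with the export helpers added in
# audiocraft/utils/export.py. All paths and output file names are placeholders.
from audiocraft.utils import export

# Export only the language model weights (best or FSDP best state).
export.export_lm("/checkpoints/my_run/checkpoint.th", "/release/state_dict.bin")

# Reference the pretrained audio tokenizer instead of copying its weights.
export.export_pretrained_compression_model("facebook/encodec_32khz",
                                            "/release/compression_state_dict.bin")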
+ """ pkg = torch.load(checkpoint_path, 'cpu') new_pkg = { - 'best_state': pkg['ema']['state']['model'], + 'best_state': pkg['best_state']['model'], 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + 'version': __version__, + 'exported': True, } - out_file = Path(out_folder) / f'{sig}.th' + Path(out_file).parent.mkdir(exist_ok=True, parents=True) torch.save(new_pkg, out_file) return out_file -def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): - sig = Path(checkpoint_path).parent.name - assert len(sig) == 8, "Not a valid Dora signature" +def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]): + """Export a compression model (potentially EnCodec) from a pretrained model. + This is required for packaging the audio tokenizer along a MusicGen or AudioGen model. + Do not include the //pretrained/ prefix. For instance if you trained a model + with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`. + + In that case, this will not actually include a copy of the model, simply the reference + to the model used. + """ + if Path(pretrained_encodec).exists(): + pkg = torch.load(pretrained_encodec) + assert 'best_state' in pkg + assert 'xp.cfg' in pkg + assert 'version' in pkg + assert 'exported' in pkg + else: + pkg = { + 'pretrained': pretrained_encodec, + 'exported': True, + 'version': __version__, + } + Path(out_file).parent.mkdir(exist_ok=True, parents=True) + torch.save(pkg, out_file) + + +def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): + """Export only the best state from the given MusicGen or AudioGen checkpoint. + """ pkg = torch.load(checkpoint_path, 'cpu') + if pkg['fsdp_best_state']: + best_state = pkg['fsdp_best_state']['model'] + else: + assert pkg['best_state'] + best_state = pkg['best_state']['model'] new_pkg = { - 'best_state': pkg['fsdp_best_state']['model'], - 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])) + 'best_state': best_state, + 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + 'version': __version__, + 'exported': True, } - out_file = Path(out_folder) / f'{sig}.th' + + Path(out_file).parent.mkdir(exist_ok=True, parents=True) torch.save(new_pkg, out_file) return out_file diff --git a/audiocraft/utils/export_legacy.py b/audiocraft/utils/export_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..b513b52267f7bf5aae09282c15b0a2e20c8a8fee --- /dev/null +++ b/audiocraft/utils/export_legacy.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility to export a training checkpoint to a lightweight release checkpoint. +""" + +from pathlib import Path +import typing as tp + +from omegaconf import OmegaConf, DictConfig +import torch + + +def _clean_lm_cfg(cfg: DictConfig): + OmegaConf.set_struct(cfg, False) + # This used to be set automatically in the LM solver, need a more robust solution + # for the future. + cfg['transformer_lm']['card'] = 2048 + cfg['transformer_lm']['n_q'] = 4 + # Experimental params no longer supported. 
+ bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', + 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] + for name in bad_params: + del cfg['transformer_lm'][name] + OmegaConf.set_struct(cfg, True) + return cfg + + +def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): + sig = Path(checkpoint_path).parent.name + assert len(sig) == 8, "Not a valid Dora signature" + pkg = torch.load(checkpoint_path, 'cpu') + new_pkg = { + 'best_state': pkg['ema']['state']['model'], + 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + } + out_file = Path(out_folder) / f'{sig}.th' + torch.save(new_pkg, out_file) + return out_file + + +def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): + sig = Path(checkpoint_path).parent.name + assert len(sig) == 8, "Not a valid Dora signature" + pkg = torch.load(checkpoint_path, 'cpu') + new_pkg = { + 'best_state': pkg['fsdp_best_state']['model'], + 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])) + } + out_file = Path(out_folder) / f'{sig}.th' + torch.save(new_pkg, out_file) + return out_file diff --git a/audiocraft/utils/extend.py b/audiocraft/utils/extend.py index 5c919a5cb740e14ca8751d68a0ab16d9400d35d6..5d75711c3a66a53f4d282a975f6fe853688c8412 100644 --- a/audiocraft/utils/extend.py +++ b/audiocraft/utils/extend.py @@ -179,7 +179,7 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap: descriptions=[text], melody_wavs=verse, sample_rate=sr, - progress=False, + progress=True, prompt=prompt_segment, ) # If user selects a prompt segment, use the prompt segment for all segments @@ -280,9 +280,10 @@ def load_font(font_name, font_size=16): if font is None: try: req = requests.get(font_name) - font = ImageFont.truetype(BytesIO(req.content), font_size) - except (FileNotFoundError, OSError): - print(f"Font not found: {font_name} Using default font\n") + font = ImageFont.truetype(BytesIO(req.content), font_size) + except (FileNotFoundError, OSError): + print(f"Font not found: {font_name} Using default font\n") + if font: print(f"Font loaded {font.getname()}") else: diff --git a/audiocraft/utils/utils.py b/audiocraft/utils/utils.py index 86e1448d065fa182ca69aae00d2f2a7eea55d8a4..16592cfe48a459cd7e7ff0477507aed9fc3749a1 100644 --- a/audiocraft/utils/utils.py +++ b/audiocraft/utils/utils.py @@ -5,9 +5,12 @@ # LICENSE file in the root directory of this source tree. from concurrent.futures import ProcessPoolExecutor -from functools import wraps +from contextlib import contextmanager +from functools import wraps, lru_cache import hashlib +import json import logging +from pathlib import Path import typing as tp import flashy @@ -20,6 +23,18 @@ from torch.nn.utils.rnn import pad_sequence logger = logging.getLogger(__name__) +def model_hash(model: torch.nn.Module) -> str: + """Return a model hash. This should allow us to track regressions in model init + from the logs of past experiments. + """ + hasher = hashlib.sha1() + for p in model.parameters(): + hasher.update(p.data.cpu().numpy().tobytes()) + return hasher.hexdigest() + + + + def dict_from_config(cfg: omegaconf.DictConfig) -> dict: """Convenience function to map an omegaconf configuration to a dictionary. @@ -172,7 +187,7 @@ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> t assert len(lengths.shape) == 1, "Length shape should be 1 dimensional." 
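# Illustrative sketch (not part of the diff hunks above): model_hash() defined
# above fingerprints all parameters, so logging it at startup makes weight-init
# regressions easy to spot when comparing experiment logs.
import torch
from audiocraft.utils.utils import model_hash

model = torch.nn.Linear(4, 4)
print("model hash:", model_hash(model))   # identical init code + seed -> identical digest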
final_length = lengths.max().item() if not max_len else max_len final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor - return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None] + return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None] def hash_trick(word: str, vocab_size: int) -> int: @@ -232,3 +247,54 @@ def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tens padded_tensors = padded_tensors.transpose(0, 1) padded_tensors = padded_tensors.transpose(1, dim + 1) return padded_tensors, lens + + +# TODO: Move to flashy? +def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu', + dtype: tp.Optional[torch.dtype] = None) -> tp.Any: + if isinstance(state, torch.Tensor): + if dtype is None or not state.is_floating_point(): + dtype = state.dtype + return state.detach().to(device=device, dtype=dtype, copy=True) + elif isinstance(state, dict): + return {k: copy_state(v, device, dtype) for k, v in state.items()} + elif isinstance(state, list): + return [copy_state(v, device, dtype) for v in state] + + +# TODO: Move to flashy? +@contextmanager +def swap_state(model, state, **kwargs): + old_state = copy_state(model.state_dict()) + model.load_state_dict(state, **kwargs) + try: + yield + finally: + model.load_state_dict(old_state) + + +@lru_cache(None) +def warn_once(logger, msg): + """Warn about a given message only once.""" + logger.warning(msg) + + +def is_jsonable(x: tp.Any): + """Check if an object can be serialized into a json:""" + try: + json.dumps(x) + return True + except (TypeError, OverflowError): + return False + + +def load_clap_state_dict(clap_model, path: tp.Union[str, Path]): + """Wrapper around state dict loading of CLAP model + addressing compatibility issues between CLAP and AudioCraft + HuggingFace transformer version. + See: https://github.com/LAION-AI/CLAP/issues/118 + """ + from clap_module.factory import load_state_dict # type: ignore + pkg = load_state_dict(path) + pkg.pop('text_branch.embeddings.position_ids', None) + clap_model.model.load_state_dict(pkg) diff --git a/modules/file_utils.py b/modules/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c606a2917f4c4b11d59462828a576a47c1e3b05b --- /dev/null +++ b/modules/file_utils.py @@ -0,0 +1,91 @@ +# file_utils +import os +import shutil +from pathlib import Path + +def get_file_parts(file_path: str): + # Split the path into directory and filename + directory, filename = os.path.split(file_path) + + # Split the filename into name and extension + name, ext = os.path.splitext(filename) + + # Convert the extension to lowercase + new_ext = ext.lower() + return directory, filename, name, ext, new_ext + +def rename_file_to_lowercase_extension(file_path: str) -> str: + """ + Renames a file's extension to lowercase in place. + + Parameters: + file_path (str): The original file path. + + Returns: + str: The new file path with the lowercase extension. + + Raises: + OSError: If there is an error renaming the file (e.g., file not found, permissions issue). + """ + directory, filename, name, ext, new_ext = get_file_parts(file_path) + # If the extension changes, rename the file + if ext != new_ext: + new_filename = name + new_ext + new_file_path = os.path.join(directory, new_filename) + try: + os.rename(file_path, new_file_path) + print(f"Rename {file_path} to {new_file_path}\n") + except Exception as e: + print(f"os.rename failed: {e}. 
Falling back to binary copy operation.") + try: + # Read the file in binary mode and write it to new_file_path + with open(file_path, 'rb') as f: + data = f.read() + with open(new_file_path, 'wb') as f: + f.write(data) + print(f"Copied {file_path} to {new_file_path}\n") + # Optionally, remove the original file after copying + #os.remove(file_path) + except Exception as inner_e: + print(f"Failed to copy file from {file_path} to {new_file_path}: {inner_e}") + raise inner_e + return new_file_path + else: + return file_path + +def get_filename(file): + # extract filename from file object + filename = None + if file is not None: + filename = file.name + return filename + +def convert_title_to_filename(title): + # convert title to filename + filename = title.lower().replace(" ", "_").replace("/", "_") + return filename + +def get_filename_from_filepath(filepath): + file_name = os.path.basename(filepath) + file_base, file_extension = os.path.splitext(file_name) + return file_base, file_extension + +def delete_file(file_path: str) -> None: + """ + Deletes the specified file. + + Parameters: + file_path (str): The path to thefile to delete. + + Raises: + FileNotFoundError: If the file does not exist. + Exception: If there is an error deleting the file. + """ + try: + path = Path(file_path) + path.unlink() + print(f"Deleted original file: {file_path}") + except FileNotFoundError: + print(f"File not found: {file_path}") + except Exception as e: + print(f"Error deleting file: {e}") \ No newline at end of file diff --git a/modules/gradio.py b/modules/gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..de6971d7046a31085f69986afeedd93b62554148 --- /dev/null +++ b/modules/gradio.py @@ -0,0 +1,272 @@ +# modules.gradio +# holds updates and lost code from gradio changes +import os +import gradio as gr +import numpy as np +import PIL +import PIL.Image +import shutil +import subprocess +from tempfile import NamedTemporaryFile +from pathlib import Path + + +class MatplotlibBackendMananger: + def __enter__(self): + try: + import matplotlib + + self._original_backend = matplotlib.get_backend() + matplotlib.use("agg") + except ImportError: + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + import matplotlib + + matplotlib.use(self._original_backend) + except ImportError: + pass + +gr.utils.MatplotlibBackendMananger = MatplotlibBackendMananger + +def make_waveform( + audio: str | tuple[int, np.ndarray], + *, + bg_color: str = "#f3f4f6", + bg_image: str | None = None, + fg_alpha: float = 0.75, + bars_color: str | tuple[str, str] = ("#fbbf24", "#ea580c"), + bar_count: int = 50, + bar_width: float = 0.6, + animate: bool = False, + name: str = "", +) -> str: + """ + Generates a waveform video from an audio file. Useful for creating an easy to share audio visualization. The output should be passed into a `gr.Video` component. + Parameters: + audio: Audio file path or tuple of (sample_rate, audio_data) + bg_color: Background color of waveform (ignored if bg_image is provided) + bg_image: Background image of waveform + fg_alpha: Opacity of foreground waveform + bars_color: Color of waveform bars. Can be a single color or a tuple of (start_color, end_color) of gradient + bar_count: Number of bars in waveform + bar_width: Width of bars in waveform. 1 represents full width, 0.5 represents half width, etc. + animate: If true, the audio waveform overlay will be animated, if false, it will be static. + Returns: + A filepath to the output video in mp4 format. 
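# Illustrative sketch (assumption, not part of the diff hunks above): calling the
# make_waveform helper introduced here, which modules/gradio.py re-attaches as
# gr.make_waveform. The audio path is a placeholder and ffmpeg must be on PATH.
import modules.gradio  # noqa: F401  (patches gr.make_waveform and the matplotlib backend manager)
import gradio as gr

video_path = gr.make_waveform(
    "output/segment_000.wav",            # or a (sample_rate, np.ndarray) tuple
    bars_color=("#fbbf24", "#ea580c"),
    bar_count=75,
    animate=False,
)
print(video_path)                        # .mp4 suitable for a gr.Video component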
+ """ + import matplotlib.pyplot as plt + from matplotlib.animation import FuncAnimation + + if isinstance(audio, str): + audio_file = audio + audio = gr.processing_utils.audio_from_file(audio) + else: + tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False, prefix = name) + gr.processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name, format="wav") + audio_file = tmp_wav.name + + if not os.path.isfile(audio_file): + raise ValueError("Audio file not found.") + + ffmpeg = shutil.which("ffmpeg") + if not ffmpeg: + raise RuntimeError("ffmpeg not found.") + + duration = round(len(audio[1]) / audio[0], 4) + + # Helper methods to create waveform + def hex_to_rgb(hex_str): + return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)] + + def get_color_gradient(c1, c2, n): + if n < 1: + raise ValueError("Must have at least one stop in gradient") + c1_rgb = np.array(hex_to_rgb(c1)) / 255 + c2_rgb = np.array(hex_to_rgb(c2)) / 255 + mix_pcts = [x / (n - 1) for x in range(n)] + rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts] + return [ + "#" + "".join(f"{int(round(val * 255)):02x}" for val in item) + for item in rgb_colors + ] + + # Reshape audio to have a fixed number of bars + samples = audio[1] + if len(samples.shape) > 1: + samples = np.mean(samples, 1) + bins_to_pad = bar_count - (len(samples) % bar_count) + samples = np.pad(samples, [(0, bins_to_pad)]) + samples = np.reshape(samples, (bar_count, -1)) + samples = np.abs(samples) + samples = np.max(samples, 1) + + with MatplotlibBackendMananger(): + plt.clf() + # Plot waveform + color = ( + bars_color + if isinstance(bars_color, str) + else get_color_gradient(bars_color[0], bars_color[1], bar_count) + ) + + if animate: + fig = plt.figure(figsize=(5, 1), dpi=200, frameon=False) + fig.subplots_adjust(left=0, bottom=0, right=1, top=1) + plt.axis("off") + plt.margins(x=0) + + bar_alpha = fg_alpha if animate else 1.0 + barcollection = plt.bar( + np.arange(0, bar_count), + samples * 2, + bottom=(-1 * samples), + width=bar_width, + color=color, + alpha=bar_alpha, + ) + + tmp_img = NamedTemporaryFile(suffix=".png", delete=False, prefix = name) + + savefig_kwargs: dict[str, Any] = {"bbox_inches": "tight"} + if bg_image is not None: + savefig_kwargs["transparent"] = True + if animate: + savefig_kwargs["facecolor"] = "none" + else: + savefig_kwargs["facecolor"] = bg_color + plt.savefig(tmp_img.name, **savefig_kwargs) + + if not animate: + waveform_img = PIL.Image.open(tmp_img.name) + waveform_img = waveform_img.resize((1000, 400)) + + # Composite waveform with background image + if bg_image is not None: + waveform_array = np.array(waveform_img) + waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha + waveform_img = PIL.Image.fromarray(waveform_array) + + bg_img = PIL.Image.open(bg_image) + waveform_width, waveform_height = waveform_img.size + bg_width, bg_height = bg_img.size + if waveform_width != bg_width: + bg_img = bg_img.resize( + ( + waveform_width, + 2 * int(bg_height * waveform_width / bg_width / 2), + ) + ) + bg_width, bg_height = bg_img.size + composite_height = max(bg_height, waveform_height) + composite = PIL.Image.new( + "RGBA", (waveform_width, composite_height), "#FFFFFF" + ) + composite.paste(bg_img, (0, composite_height - bg_height)) + composite.paste( + waveform_img, (0, composite_height - waveform_height), waveform_img + ) + composite.save(tmp_img.name) + img_width, img_height = composite.size + else: + img_width, img_height = waveform_img.size + waveform_img.save(tmp_img.name) + else: + + def 
_animate(_): + for idx, b in enumerate(barcollection): + rand_height = np.random.uniform(0.8, 1.2) + b.set_height(samples[idx] * rand_height * 2) + b.set_y((-rand_height * samples)[idx]) + + frames = int(duration * 10) + anim = FuncAnimation( + fig, # type: ignore + _animate, # type: ignore + repeat=False, + blit=False, + frames=frames, + interval=100, + ) + anim.save( + tmp_img.name, + writer="pillow", + fps=10, + codec="png", + savefig_kwargs=savefig_kwargs, + ) + + # Convert waveform to video with ffmpeg + output_mp4 = NamedTemporaryFile(suffix=".mp4", delete=False, prefix = name) + + if animate and bg_image is not None: + ffmpeg_cmd = [ + ffmpeg, + "-loop", + "1", + "-i", + bg_image, + "-i", + tmp_img.name, + "-i", + audio_file, + "-filter_complex", + "[0:v]scale=w=trunc(iw/2)*2:h=trunc(ih/2)*2[bg];[1:v]format=rgba,colorchannelmixer=aa=1.0[ov];[bg][ov]overlay=(main_w-overlay_w*0.9)/2:main_h-overlay_h*0.9/2[output]", + "-t", + str(duration), + "-map", + "[output]", + "-map", + "2:a", + "-c:v", + "libx264", + "-c:a", + "aac", + "-shortest", + "-y", + output_mp4.name, + ] + elif animate and bg_image is None: + ffmpeg_cmd = [ + ffmpeg, + "-i", + tmp_img.name, + "-i", + audio_file, + "-filter_complex", + "[0:v][1:a]concat=n=1:v=1:a=1[v];[v]scale=1000:400,format=yuv420p[v_scaled]", + "-map", + "[v_scaled]", + "-map", + "1:a", + "-c:v", + "libx264", + "-c:a", + "aac", + "-shortest", + "-y", + output_mp4.name, + ] + else: + ffmpeg_cmd = [ + ffmpeg, + "-loop", + "1", + "-i", + tmp_img.name, + "-i", + audio_file, + "-vf", + f"color=c=#FFFFFF77:s={img_width}x{img_height}[bar];[0][bar]overlay=-w+(w/{duration})*t:H-h:shortest=1", # type: ignore + "-t", + str(duration), + "-y", + output_mp4.name, + ] + + subprocess.check_call(ffmpeg_cmd) + return output_mp4.name + +gr.make_waveform = make_waveform \ No newline at end of file diff --git a/modules/user_history.py b/modules/user_history.py new file mode 100644 index 0000000000000000000000000000000000000000..84b71ccca943badea8edacd96fd9ce3ef8a9e713 --- /dev/null +++ b/modules/user_history.py @@ -0,0 +1,598 @@ +""" +User History is a plugin that you can add to your Spaces to cache generated images for your users. + +Key features: +- 🤗 Sign in with Hugging Face +- Save generated image, video, audio and document files with their metadata: prompts, timestamp, hyper-parameters, etc. +- Export your history as zip. +- Delete your history to respect privacy. +- Compatible with Persistent Storage for long-term storage. +- Admin panel to check configuration and disk usage . 
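# Illustrative sketch (assumption, not part of the diff hunks above): wiring the
# user_history module described above into the Gradio app. run_model() is a stub
# standing in for the actual generation call, and this revision of save_file()
# expects the video, audio and document slots to all be provided, since it builds
# its JSON record from each of them.
import gradio as gr
import modules.user_history as user_history

def run_model(prompt: str):
    # placeholder: return (video, audio, metadata-document) paths from generation
    return "out/clip.mp4", "out/clip.wav", "out/clip.json"

def generate(prompt: str, profile: gr.OAuthProfile | None):
    video_path, audio_path, json_path = run_model(prompt)
    user_history.save_file(
        profile=profile,
        video=video_path,
        audio=audio_path,
        document=json_path,
        label=prompt,
        metadata={"prompt": prompt, "model": "musicgen-medium"},
    )
    return video_path

with gr.Blocks() as demo:
    with gr.Tab("Generate"):
        prompt = gr.Textbox(label="Prompt")
        out = gr.Video()
        gr.Button("Generate").click(generate, inputs=[prompt], outputs=[out])
    with gr.Tab("Past generations"):
        user_history.render()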
+ +Useful links: +- Demo: https://huggingface.co/spaces/Wauplin/gradio-user-history +- README: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/README.md +- Source file: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/user_history.py +- Discussions: https://huggingface.co/spaces/Wauplin/gradio-user-history/discussions +""" + +__version__ = "0.2.0" + +import json +import os +import shutil +import warnings +from datetime import datetime +from functools import cache +from pathlib import Path +from typing import Callable, Dict, List, Tuple, Any +from uuid import uuid4 + +import gradio as gr +import numpy as np +import requests +from filelock import FileLock +from PIL.Image import Image +import filetype +import wave +from mutagen.mp3 import MP3, EasyMP3 +import torchaudio +import subprocess + + +def setup(folder_path: str | Path | None = None) -> None: + user_history = _UserHistory() + user_history.folder_path = _resolve_folder_path(folder_path) + user_history.initialized = True + + +def render() -> None: + user_history = _UserHistory() + + # initialize with default config + if not user_history.initialized: + print("Initializing user history with default config. Use `user_history.setup(...)` to customize folder_path.") + setup() + + # Render user history tab + gr.Markdown( + "## Your past generations\n\nLog in to keep a gallery of your previous generations. Your history will be saved" + " and available on your next visit. Make sure to export your images from time to time as this gallery may be" + " deleted in the future." + ) + + if os.getenv("SYSTEM") == "spaces" and not os.path.exists("/data"): + gr.Markdown( + "**⚠️ Persistent storage is disabled, meaning your history will be lost if the Space gets restarted." + " Only the Space owner can setup a Persistent Storage. If you are not the Space owner, consider" + " duplicating this Space to set your own storage.⚠️**" + ) + + with gr.Row(): + gr.LoginButton(min_width=250) + #gr.LogoutButton(min_width=250) + refresh_button = gr.Button( + "Refresh", + icon="./assets/icon_refresh.png", + ) + export_button = gr.Button( + "Export", + icon="./assets/icon_download.png", + ) + delete_button = gr.Button( + "Delete history", + icon="./assets/icon_delete.png", + ) + + # "Export zip" row (hidden by default) + with gr.Row(): + export_file = gr.File(file_count="single", file_types=[".zip"], label="Exported history", visible=False) + + # "Config deletion" row (hidden by default) + with gr.Row(): + confirm_button = gr.Button("Confirm delete all history", variant="stop", visible=False) + cancel_button = gr.Button("Cancel", visible=False) + + # Gallery + gallery = gr.Gallery( + label="Past images", + show_label=True, + elem_id="gradio_user_history_gallery", + object_fit="cover", + columns=5, + height=600, + preview=False, + show_share_button=False, + show_download_button=True, + ) + gr.Markdown( + "User history is powered by" + " [Wauplin/gradio-user-history](https://huggingface.co/spaces/Wauplin/gradio-user-history). Integrate it to" + " your own Space in just a few lines of code!" 
+ ) + gallery.attach_load_event(_fetch_user_history, every=None) + + # Interactions + refresh_button.click(fn=_fetch_user_history, inputs=[], outputs=[gallery], queue=False) + export_button.click(fn=_export_user_history, inputs=[], outputs=[export_file], queue=False) + + # Taken from https://github.com/gradio-app/gradio/issues/3324#issuecomment-1446382045 + delete_button.click( + lambda: [gr.update(visible=True), gr.update(visible=True)], + outputs=[confirm_button, cancel_button], + queue=False, + ) + cancel_button.click( + lambda: [gr.update(visible=False), gr.update(visible=False)], + outputs=[confirm_button, cancel_button], + queue=False, + ) + confirm_button.click(_delete_user_history).then( + lambda: [gr.update(visible=False), gr.update(visible=False)], + outputs=[confirm_button, cancel_button], + queue=False, + ) + + # Admin section (only shown locally or when logged in as Space owner) + _admin_section() + + +def save_image( + profile: gr.OAuthProfile | None, + image: Image | np.ndarray | str | Path, + label: str | None = None, + metadata: Dict | None = None, +): + # Ignore images from logged out users + if profile is None: + return + username = profile["preferred_username"] + + # Ignore images if user history not used + user_history = _UserHistory() + if not user_history.initialized: + warnings.warn( + "User history is not set in Gradio demo. Saving image is ignored. You must use `user_history.render(...)`" + " first." + ) + return + + # Copy image to storage + image_path = _copy_image(image, dst_folder=user_history._user_images_path(username)) + + # Save new image + metadata + if metadata is None: + metadata = {} + if "datetime" not in metadata: + metadata["datetime"] = str(datetime.now()) + data = {"path": str(image_path), "label": label, "metadata": metadata} + with user_history._user_lock(username): + with user_history._user_jsonl_path(username).open("a") as f: + f.write(json.dumps(data) + "\n") + +def save_file( + profile: gr.OAuthProfile | None, + image: Image | np.ndarray | str | Path | None = None, + video: str | Path | None = None, + audio: str | Path | None = None, + document: str | Path | None = None, + label: str | None = None, + metadata: Dict | None = None, +): + # Ignore files from logged out users + if profile is None: + return + username = profile["preferred_username"] + + # Ignore files if user history not used + user_history = _UserHistory() + if not user_history.initialized: + warnings.warn( + "User history is not set in Gradio demo. Saving files is ignored. You must use `user_history.render(...)`" + " first." 
+ ) + return + + # Save new files + metadata + if metadata is None: + metadata = {} + if "datetime" not in metadata: + metadata["datetime"] = str(datetime.now()) + + # Copy image to storage + image_path = None + if image is not None: + image_path = _copy_image(image, dst_folder=user_history._user_images_path(username)) + image_path = _add_metadata(image_path, metadata) + + # Copy video to storage + if video is not None: + video_path = _copy_file(video, dst_folder=user_history._user_file_path(username, "videos")) + video_path = _add_metadata(video_path, metadata) + + # Copy audio to storage + if audio is not None: + audio_path = _copy_file(audio, dst_folder=user_history._user_file_path(username, "audios")) + audio_path = _add_metadata(audio_path, metadata) + + # Copy document to storage + if document is not None: + document_path = _copy_file(document, dst_folder=user_history._user_file_path(username, "documents")) + document_path = _add_metadata(document_path, metadata) + + # Save Json file + data = {"image_path": str(image_path), "video_path": str(video_path), "audio_path": str(audio_path), "document_path": str(document_path), "label": label, "metadata": metadata} + with user_history._user_lock(username): + with user_history._user_jsonl_path(username).open("a") as f: + f.write(json.dumps(data) + "\n") + + +############# +# Internals # +############# + + +class _UserHistory(object): + _instance = None + initialized: bool = False + folder_path: Path + + def __new__(cls): + # Using singleton pattern => we don't want to expose an object (more complex to use) but still want to keep + # state between `render` and `save_image` calls. + if cls._instance is None: + cls._instance = super(_UserHistory, cls).__new__(cls) + return cls._instance + + def _user_path(self, username: str) -> Path: + path = self.folder_path / username + path.mkdir(parents=True, exist_ok=True) + return path + + def _user_lock(self, username: str) -> FileLock: + """Ensure history is not corrupted if concurrent calls.""" + return FileLock(self.folder_path / f"{username}.lock") # lock outside of folder => better when exporting ZIP + + def _user_jsonl_path(self, username: str) -> Path: + return self._user_path(username) / "history.jsonl" + + def _user_images_path(self, username: str) -> Path: + path = self._user_path(username) / "images" + path.mkdir(parents=True, exist_ok=True) + return path + + def _user_file_path(self, username: str, filetype: str = "images") -> Path: + path = self._user_path(username) / filetype + path.mkdir(parents=True, exist_ok=True) + return path + + + +def _fetch_user_history(profile: gr.OAuthProfile | None) -> List[Tuple[str, str]]: + """Return saved history for that user, if it exists.""" + # Cannot load history for logged out users + if profile is None: + return [] + username = profile["preferred_username"] + + user_history = _UserHistory() + if not user_history.initialized: + warnings.warn("User history is not set in Gradio demo. 
You must use `user_history.render(...)` first.") + return [] + + with user_history._user_lock(username): + # No file => no history saved yet + jsonl_path = user_history._user_jsonl_path(username) + if not jsonl_path.is_file(): + return [] + + # Read history + images = [] + for line in jsonl_path.read_text().splitlines(): + data = json.loads(line) + images.append((data["path"], data["label"] or "")) + return list(reversed(images)) + + +def _export_user_history(profile: gr.OAuthProfile | None) -> Dict | None: + """Zip all history for that user, if it exists and return it as a downloadable file.""" + # Cannot load history for logged out users + if profile is None: + return None + username = profile["preferred_username"] + + user_history = _UserHistory() + if not user_history.initialized: + warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.") + return None + + # Zip history + with user_history._user_lock(username): + path = shutil.make_archive( + str(_archives_path() / f"history_{username}"), "zip", user_history._user_path(username) + ) + + return gr.update(visible=True, value=path) + + +def _delete_user_history(profile: gr.OAuthProfile | None) -> None: + """Delete all history for that user.""" + # Cannot load history for logged out users + if profile is None: + return + username = profile["preferred_username"] + + user_history = _UserHistory() + if not user_history.initialized: + warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.") + return + + with user_history._user_lock(username): + shutil.rmtree(user_history._user_path(username)) + + +#################### +# Internal helpers # +#################### + + +def _copy_image(image: Image | np.ndarray | str | Path, dst_folder: Path) -> Path: + try: + """Copy image to the images folder.""" + # Already a path => copy it + if isinstance(image, str): + image = Path(image) + if isinstance(image, Path): + dst = dst_folder / f"{uuid4().hex}_{Path(image).name}" # keep file ext + shutil.copyfile(image, dst) + return dst + + # Still a Python object => serialize it + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + if isinstance(image, Image): + dst = dst_folder / f"{Path(file).name}_{uuid4().hex}.png" + image.save(dst) + return dst + + raise ValueError(f"Unsupported image type: {type(image)}") + + except Exception as e: + print(f"An error occurred: {e}") + if not isinstance(dst, Path): + dst = Path(image) + return dst # Return the original file_location if an error occurs + +def _copy_file(file: Any | np.ndarray | str | Path, dst_folder: Path) -> Path: + try: + """Copy file to the appropriate folder.""" + # Already a path => copy it + if isinstance(file, str): + file = Path(file) + if isinstance(file, Path): + dst = dst_folder / f"{file.stem}_{uuid4().hex}{file.suffix}" # keep file ext + shutil.copyfile(file, dst) + return dst + + # Still a Python object => serialize it + if isinstance(file, np.ndarray): + file = Image.fromarray(file) + dst = dst_folder / f"{file.filename}_{uuid4().hex}{file.suffix}" + file.save(dst) + return dst + + # try other file types + kind = filetype.guess(file) + if kind is not None: + dst = dst_folder / f"{Path(file).stem}_{uuid4().hex}.{kind.extension}" + shutil.copyfile(file, dst) + return dst + raise ValueError(f"Unsupported file type: {type(file)}") + + except Exception as e: + print(f"An error occurred: {e}") + if not isinstance(dst, Path): + dst = Path(file) + return dst # Return the original 
file_location if an error occurs + + +def _add_metadata(file_location: Path, metadata: Dict[str, Any]) -> Path: + try: + file_type = file_location.suffix + valid_file_types = [".wav", ".mp3", ".mp4", ".png"] + if file_type not in valid_file_types: + raise ValueError("Invalid file type. Valid file types are .wav, .mp3, .mp4, .png") + + if file_type == ".wav": + # Open and process .wav file + with wave.open(file_location, 'rb') as wav_file: + # Get the current metadata + current_metadata = {key: value for key, value in wav_file.getparams()._asdict().items() if isinstance(value, (int, float))} + + # Update metadata + current_metadata.update(metadata) + + # Reopen the WAV file in write mode + with wave.open(file_location, 'wb') as wav_output_file: + # Set the new metadata + wav_output_file.setparams(wav_file.getparams()) + + # Save the WAV file (overwriting the previous version) + wav_output_file.close() + elif file_type == ".mp3": + # Open and process .mp3 file + audio = EasyMP3(file_location) + + # Add metadata to the file + for key, value in metadata.items(): + audio[key] = value + + # Save the MP3 file (overwriting the previous version) + audio.save() + elif file_type == ".mp4": + # Open and process .mp4 file + # Add metadata to the file + wav_file_location = file_location.with_suffix(".wav") + wave_exists = wav_file_location.exists() + if not wave_exists: + # Use torchaudio to create the WAV file if it doesn't exist + audio, sample_rate = torchaudio.load(file_location, normalize=True) + torchaudio.save(wav_file_location, audio, sample_rate, format='wav') + + # Use ffmpeg to add metadata to the video file + metadata_args = [f"{key}={value}" for key, value in metadata.items()] + ffmpeg_metadata = ":".join(metadata_args) + ffmpeg_cmd = f'ffmpeg -i "{file_location}" -i "{wav_file_location}" -map 0:v:0 -map 1:a:0 -c:v copy -c:a aac -metadata "{ffmpeg_metadata}" "{file_location}"' + subprocess.run(ffmpeg_cmd, shell=True, check=True) + + # Remove temporary WAV file + if not wave_exists: + wav_file_location.unlink() + elif file_type == ".png": + # Open and process .png file + image = Image.open(file_location) + exif_data = image.info.get("exif", {}) + exif_data.update(metadata) + # Add metadata to the file + image.save(file_location, exif=exif_data) + + return file_location # Return the path to the modified file + + except Exception as e: + print(f"An error occurred: {e}") + return file_location # Return the original file_location if an error occurs + +def _resolve_folder_path(folder_path: str | Path | None) -> Path: + if folder_path is not None: + return Path(folder_path).expanduser().resolve() + + if os.getenv("SYSTEM") == "spaces" and os.path.exists("/data"): # Persistent storage is enabled! 
+ return Path("/data") / "_user_history" + + # Not in a Space or Persistent storage not enabled => local folder + return Path("_user_history").resolve() + + +def _archives_path() -> Path: + # Doesn't have to be on persistent storage as it's only used for download + path = Path(__file__).parent / "_user_history_exports" + path.mkdir(parents=True, exist_ok=True) + return path + + +################# +# Admin section # +################# + + +def _admin_section() -> None: + title = gr.Markdown() + title.attach_load_event(_display_if_admin(), every=None) + + +def _display_if_admin() -> Callable: + def _inner(profile: gr.OAuthProfile | None) -> str: + if profile is None: + return "" + if profile["preferred_username"] in _fetch_admins(): + return _admin_content() + return "" + + return _inner + + +def _admin_content() -> str: + return f""" +## Admin section + +Running on **{os.getenv("SYSTEM", "local")}** (id: {os.getenv("SPACE_ID")}). {_get_msg_is_persistent_storage_enabled()} + +Admins: {', '.join(_fetch_admins())} + +{_get_nb_users()} user(s), {_get_nb_images()} image(s) + +### Configuration + +History folder: *{_UserHistory().folder_path}* + +Exports folder: *{_archives_path()}* + +### Disk usage + +{_disk_space_warning_message()} +""" + + +def _get_nb_users() -> int: + user_history = _UserHistory() + if not user_history.initialized: + return 0 + if user_history.folder_path is not None and user_history.folder_path.exists(): + return len([path for path in user_history.folder_path.iterdir() if path.is_dir()]) + return 0 + + +def _get_nb_images() -> int: + user_history = _UserHistory() + if not user_history.initialized: + return 0 + if user_history.folder_path is not None and user_history.folder_path.exists(): + return len([path for path in user_history.folder_path.glob("*/images/*")]) + return 0 + + +def _get_msg_is_persistent_storage_enabled() -> str: + if os.getenv("SYSTEM") == "spaces": + if os.path.exists("/data"): + return "Persistent storage is enabled." + else: + return ( + "Persistent storage is not enabled. This means that user histories will be deleted when the Space is" + " restarted. Consider adding a Persistent Storage in your Space settings." + ) + return "" + + +def _disk_space_warning_message() -> str: + user_history = _UserHistory() + if not user_history.initialized: + return "" + + message = "" + if user_history.folder_path is not None: + total, used, _ = _get_disk_usage(user_history.folder_path) + message += f"History folder: **{used / 1e9 :.0f}/{total / 1e9 :.0f}GB** used ({100*used/total :.0f}%)." + + total, used, _ = _get_disk_usage(_archives_path()) + message += f"\n\nExports folder: **{used / 1e9 :.0f}/{total / 1e9 :.0f}GB** used ({100*used/total :.0f}%)." + + return f"{message.strip()}" + + +def _get_disk_usage(path: Path) -> Tuple[int, int, int]: + for path in [path] + list(path.parents): # first check target_dir, then each parents one by one + try: + return shutil.disk_usage(path) + except OSError: # if doesn't exist or can't read => fail silently and try parent one + pass + return 0, 0, 0 + + +@cache +def _fetch_admins() -> List[str]: + # Running locally => fake user is admin + if os.getenv("SYSTEM") != "spaces": + return ["FakeGradioUser"] + + # Running in Space but no space_id => ??? 
+ space_id = os.getenv("SPACE_ID") + if space_id is None: + return ["Unknown"] + + # Running in Space => try to fetch organization members + # Otherwise, it's not an organization => namespace is the user + namespace = space_id.split("/")[0] + response = requests.get(f"https://huggingface.co/api/organizations/{namespace}/members") + if response.status_code == 200: + return sorted((member["user"] for member in response.json()), key=lambda x: x.lower()) + return [namespace] diff --git a/modules/version_info.py b/modules/version_info.py new file mode 100644 index 0000000000000000000000000000000000000000..e345e4638b3f31cf575a3ee9f9f27a4651b747a3 --- /dev/null +++ b/modules/version_info.py @@ -0,0 +1,123 @@ +# modules/version_info.py + +from audiocraft import __version__ as audiocraft_version +import subprocess +import os +import sys +import gc +import gradio as gr + +git = os.environ.get('GIT', "git") + +def commit_hash(): + try: + return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip() + except Exception: + return "" + +def get_xformers_version(): + try: + import xformers + return xformers.__version__ + except Exception: + return "" +def get_transformers_version(): + try: + import transformers + return transformers.__version__ + except Exception: + return "" + +def get_accelerate_version(): + try: + import accelerate + return accelerate.__version__ + except Exception: + return "" +def get_safetensors_version(): + try: + import safetensors + return safetensors.__version__ + except Exception: + return "" +def get_diffusers_version(): + try: + import diffusers + return diffusers.__version__ + except Exception: + return "" + +def get_torch_info(): + from torch import __version__ as torch_version_, version, cuda, backends + device_type = initialize_cuda() + if device_type == "cuda": + try: + info = [torch_version_, f"CUDA Version:{version.cuda}", f"Available:{cuda.is_available()}", f"flash attention enabled: {backends.cuda.flash_sdp_enabled()}", f"Capabilities: {cuda.get_device_capability(0)}", f"Device Name: {cuda.get_device_name(0)}", f"Device Count: {cuda.device_count()}"] + del torch_version_, version, cuda, backends + return info + except Exception: + del torch_version_, version, cuda, backends + return "" + else: + return "Not Recognized" + +def release_torch_resources(): + from torch import cuda + # Clear the CUDA cache + cuda.empty_cache() + cuda.ipc_collect() + # Delete any objects that are using GPU memory + #for obj in gc.get_objects(): + # if is_tensor(obj) or (hasattr(obj, 'data') and is_tensor(obj.data)): + # del obj + # Run garbage collection + del cuda + gc.collect() + + +def initialize_cuda(): + from torch import cuda, version + if cuda.is_available(): + device = cuda.device("cuda") + print(f"CUDA is available. Using device: {cuda.get_device_name(0)} with CUDA version: {version.cuda}") + result = "cuda" + else: + print("CUDA is not available. 
Using CPU.") + result = "cpu" + return result + +def versions_html(): + from torch import __version__ as torch_version_ + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = commit_hash() + + # Define the Toggle Dark Mode link with JavaScript + toggle_dark_link = ''' + + Toggle Dark Mode + + ''' + + v_html = f""" + version: " else commit}" target="_blank">{"huggingface" if commit == "" else commit} +  •  + autocraft: {audiocraft_version} +  •  + python: {python_version} +  •  + torch: {torch_version_} +  •  + xformers: {get_xformers_version()} +  •  + transformers: {get_transformers_version()} +  •  + safetensors: {get_safetensors_version()} +  •  + gradio: {gr.__version__} +  •  + {toggle_dark_link} +
+ Full GPU Info:{get_torch_info()} + """ + del torch_version_ + return v_html \ No newline at end of file diff --git a/pre-requirements.txt b/pre-requirements.txt index c136debd999692eebebd51fb83ba54487a15ee2d..c923cbb5540c90da0f504c00ffb9246f44c03576 100644 --- a/pre-requirements.txt +++ b/pre-requirements.txt @@ -1 +1 @@ -pip>=23.3 \ No newline at end of file +pip>=24.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index e2bd9caf449122e5ea16ca2109a6c7e5ecad79a8..5c614543499e46270f1a352a9fe90b02b2f07074 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,38 @@ # please make sure you have already a pytorch install that is cuda enabled! -av +av==11.0.0 einops flashy>=0.0.1 hydra-core>=1.1 hydra_colorlog -julius -num2words -numpy -sentencepiece -spacy==3.5.2 -torch==2.0.1 -torchaudio==2.0.2 +torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124 +torchaudio>=2.0.0,<2.6.2 --extra-index-url https://download.pytorch.org/whl/cu124 soundfile huggingface_hub tqdm -transformers>=4.31.0 -xformers>=0.0.22 +transformers>=4.48.0 # need Encodec there. +xformers>=0.0.23 --index-url https://download.pytorch.org/whl/cu124 demucs librosa -gradio==3.38.00 -pillow \ No newline at end of file +soundfile +gradio==5.23.3 +gradio[oauth] +pillow +torchmetrics +encodec +protobuf>=3.20.1 +filetype +wave +mutagen +fastapi>=0.88.0 +pydantic +typer +torchvision>=0.21.0 --extra-index-url https://download.pytorch.org/whl/cu124 +#torchtext +pesq +pystoi +julius +spacy==3.7.6 +sentencepiece +num2words +numpy<1.26.4 +matplotlib \ No newline at end of file diff --git a/style_20250331.css b/style_20250331.css new file mode 100644 index 0000000000000000000000000000000000000000..8280063fdffe544007f82bb3b62a9673f784ad55 --- /dev/null +++ b/style_20250331.css @@ -0,0 +1,215 @@ +.interface-wrapper { + max-width: 1024px; + margin: 0 auto; +} + +.centered { + margin: 0 auto; + display: block; + text-align:center; +} + +.solid { + opacity: 1.0 !important; + height: auto !important; +} + +.intro { + font-size: 1.2em !important; + font-weight: bold; + text-align: center; + background-color: rgba(242, 218, 163, 0.62); +} + +.dark .gradio-container.gradio-container-5-23-3 .contain .intro .prose { + background-color: rgba(41, 18, 5, 0.38) !important; +} +.toast-body.info { + background-color: rgba(242, 218, 163, 0.75); +} +.dark .toast-body.info { + background-color: rgba(128, 128, 128, 0.75); +} + +.small { + font-size: smaller !important; + text-align: center; +} + +.imgcontainer img { + object-fit: contain !important; +} + +#examples { + font-weight: bolder; +} + +--background-fill-primary: #FBCE50 !important; +#col-container { + max-width: 1024px; + margin-left: auto; + margin-right: auto; +} + +a { + text-decoration-line: underline; + font-weight: 600; +} + +#btn-generate { + background-image: linear-gradient(to right bottom, rgb(157, 255, 157), rgb(229, 255, 235)); + color: var(--primary-800); +} + + #btn-generate:hover { + background-image: linear-gradient(to right bottom, rgb(229, 255, 229), rgb(255, 255, 255)); + } + + #btn-generate:active { + background-image: linear-gradient(to right bottom, rgb(229, 255, 235), rgb(157, 255, 157)); + } + +#versions { + margin-top: 1em; + width: 100%; + text-align: center; +} + +.small-btn { + max-width: 75px; +} + +#gallery .thumbnails, #lora_gallery .thumbnails { + flex-direction: column !important; + display: inline-flex !important; + flex-wrap: wrap !important; + position: relative !important; +} + +#gallery caption.caption, 
#lora_gallery caption.caption { + flex-direction: row !important; + display: inline-flex !important; + flex-wrap: wrap; + white-space: unset !important; +} + +#gallery .image-button img.with-caption, #lora_gallery .image-button img.with-caption { + object-fit: cover !important; + object-position: center !important; +} + +#gallery button.preview, #lora_gallery button.preview { + position: relative !important; +} + +.gradio-container::before { + content: ' '; + display: block; + position: absolute; + left: 0; + top: 0; + width: 100%; + height: 100%; + opacity: 0.5; + background-image: url('gradio_api/file=./assets/Vermilion-Musical-Notes-Typography-No-Background.svg'); + background-repeat: no-repeat; + background-position: 50% 25%; + /*background-color: rgba(0,0,0,0.5);*/ + background-size: 45vh; + overflow: hidden; +} + +.gradio-container::after { + content: ''; + position: absolute; + top: 0; + left: -60%; /* Start off-screen */ + width: 30%; + height: 100%; + background: linear-gradient( 120deg, rgba(255, 255, 255, 0) 10%, rgba(255, 255, 255, 0.60) 50%, rgba(255, 255, 255, 0) 90% ); + animation: shine 30s infinite; +} + +#component-0, #component-1 { + opacity: 0.9; +} + +#excluded_colors { + width: 95%; + margin: 0 auto; + font-size: smaller; +} + +@media only screen and (min-width: 1920px) { + .gradio-container, .gradio-container::before { + max-width: 1920px !important; + } +} + +.sidebar .toggle-button::before { + content: 'Sketch Pad'; + font-weight: bold; + transform: rotate(180deg); + margin-right: -120px; + width: 120px; + background-color: rgba(242, 218, 163, 0.62); +} +.dark .sidebar .toggle-button::before { + background-color: rgba(41, 18, 5, 0.38) !important; +} + .sidebar.open .toggle-button::before { + content: ''; +} + +#sketchpd, #filters, #image_gen, #accordian_3d { + outline-color: #bbf7d0; + outline-style:solid; + outline-width: 1px; + outline-offset: 1px; + padding: 2px; + border-radius:6px; +} +.outline-important { + outline-color: var(--accordion-text-color); + outline-style: solid; + outline-width: 2px; + outline-offset: 2px; + padding: 2px; + border-radius: 6px; +} +.selected.svelte-1tcem6n.svelte-1tcem6n { + font-size: large; + font-weight: bold; + color: var(--body-text-color); +} +.tab-wrapper.svelte-1tcem6n.svelte-1tcem6n { + height: var(--size-12); + padding-bottom: var(--size-1); + text-align: center; + background-blend-mode: multiply; + border-radius: var(--block-radius); + background-color: var(--block-background-fill); + + outline-color: var(--accordion-text-color); + outline-style: solid; + outline-width: 2px; + outline-offset: 2px; + padding: 2px; + border-radius: 6px; +} + + + +@keyframes shine { + 0% { + left: -100%; + } + + 20% { + left: 100%; + } + + 100% { + left: 125%; + } +} \ No newline at end of file diff --git a/user_history.py b/user_history.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a9120f8cb4ed5eff075f08981fb8cf7c102527 --- /dev/null +++ b/user_history.py @@ -0,0 +1,598 @@ +""" +User History is a plugin that you can add to your Spaces to cache generated images for your users. + +Key features: +- 🤗 Sign in with Hugging Face +- Save generated image, video, audio and document files with their metadata: prompts, timestamp, hyper-parameters, etc. +- Export your history as zip. +- Delete your history to respect privacy. +- Compatible with Persistent Storage for long-term storage. +- Admin panel to check configuration and disk usage . 
+ +Useful links: +- Demo: https://huggingface.co/spaces/Wauplin/gradio-user-history +- README: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/README.md +- Source file: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/user_history.py +- Discussions: https://huggingface.co/spaces/Wauplin/gradio-user-history/discussions +""" + +__version__ = "0.2.0" + +import json +import os +import shutil +import warnings +from datetime import datetime +from functools import cache +from pathlib import Path +from typing import Callable, Dict, List, Tuple +from uuid import uuid4 + +import gradio as gr +import numpy as np +import requests +from filelock import FileLock +from PIL.Image import Image +import filetype +import wave +from mutagen.mp3 import MP3, EasyMP3 +import torchaudio +import subprocess + + +def setup(folder_path: str | Path | None = None) -> None: + user_history = _UserHistory() + user_history.folder_path = _resolve_folder_path(folder_path) + user_history.initialized = True + + +def render() -> None: + user_history = _UserHistory() + + # initialize with default config + if not user_history.initialized: + print("Initializing user history with default config. Use `user_history.setup(...)` to customize folder_path.") + setup() + + # Render user history tab + gr.Markdown( + "## Your past generations\n\nLog in to keep a gallery of your previous generations. Your history will be saved" + " and available on your next visit. Make sure to export your images from time to time as this gallery may be" + " deleted in the future." + ) + + if os.getenv("SYSTEM") == "spaces" and not os.path.exists("/data"): + gr.Markdown( + "**⚠️ Persistent storage is disabled, meaning your history will be lost if the Space gets restarted." + " Only the Space owner can setup a Persistent Storage. If you are not the Space owner, consider" + " duplicating this Space to set your own storage.⚠️**" + ) + + with gr.Row(): + gr.LoginButton(min_width=250) + #gr.LogoutButton(min_width=250) + refresh_button = gr.Button( + "Refresh", + icon="./assets/icon_refresh.png", + ) + export_button = gr.Button( + "Export", + icon="./assets/icon_download.png", + ) + delete_button = gr.Button( + "Delete history", + icon="./assets/icon_delete.png", + ) + + # "Export zip" row (hidden by default) + with gr.Row(): + export_file = gr.File(file_count="single", file_types=[".zip"], label="Exported history", visible=False) + + # "Config deletion" row (hidden by default) + with gr.Row(): + confirm_button = gr.Button("Confirm delete all history", variant="stop", visible=False) + cancel_button = gr.Button("Cancel", visible=False) + + # Gallery + gallery = gr.Gallery( + label="Past images", + show_label=True, + elem_id="gradio_user_history_gallery", + object_fit="contain", + columns=5, + height=600, + preview=False, + show_share_button=False, + show_download_button=False, + ) + gr.Markdown( + "User history is powered by" + " [Wauplin/gradio-user-history](https://huggingface.co/spaces/Wauplin/gradio-user-history). Integrate it to" + " your own Space in just a few lines of code!" 
+ ) + gallery.attach_load_event(_fetch_user_history, every=None) + + # Interactions + refresh_button.click(fn=_fetch_user_history, inputs=[], outputs=[gallery], queue=False) + export_button.click(fn=_export_user_history, inputs=[], outputs=[export_file], queue=False) + + # Taken from https://github.com/gradio-app/gradio/issues/3324#issuecomment-1446382045 + delete_button.click( + lambda: [gr.update(visible=True), gr.update(visible=True)], + outputs=[confirm_button, cancel_button], + queue=False, + ) + cancel_button.click( + lambda: [gr.update(visible=False), gr.update(visible=False)], + outputs=[confirm_button, cancel_button], + queue=False, + ) + confirm_button.click(_delete_user_history).then( + lambda: [gr.update(visible=False), gr.update(visible=False)], + outputs=[confirm_button, cancel_button], + queue=False, + ) + + # Admin section (only shown locally or when logged in as Space owner) + _admin_section() + + +def save_image( + profile: gr.OAuthProfile | None, + image: Image | np.ndarray | str | Path, + label: str | None = None, + metadata: Dict | None = None, +): + # Ignore images from logged out users + if profile is None: + return + username = profile["preferred_username"] + + # Ignore images if user history not used + user_history = _UserHistory() + if not user_history.initialized: + warnings.warn( + "User history is not set in Gradio demo. Saving image is ignored. You must use `user_history.render(...)`" + " first." + ) + return + + # Copy image to storage + image_path = _copy_image(image, dst_folder=user_history._user_images_path(username)) + + # Save new image + metadata + if metadata is None: + metadata = {} + if "datetime" not in metadata: + metadata["datetime"] = str(datetime.now()) + data = {"path": str(image_path), "label": label, "metadata": metadata} + with user_history._user_lock(username): + with user_history._user_jsonl_path(username).open("a") as f: + f.write(json.dumps(data) + "\n") + +def save_file( + profile: gr.OAuthProfile | None, + image: Image | np.ndarray | str | Path | None = None, + video: str | Path | None = None, + audio: str | Path | None = None, + document: str | Path | None = None, + label: str | None = None, + metadata: Dict | None = None, +): + # Ignore files from logged out users + if profile is None: + return + username = profile["preferred_username"] + + # Ignore files if user history not used + user_history = _UserHistory() + if not user_history.initialized: + warnings.warn( + "User history is not set in Gradio demo. Saving files is ignored. You must use `user_history.render(...)`" + " first." 
+ ) + return + + # Save new files + metadata + if metadata is None: + metadata = {} + if "datetime" not in metadata: + metadata["datetime"] = str(datetime.now()) + + # Copy image to storage + image_path = None + if image is not None: + image_path = _copy_image(image, dst_folder=user_history._user_images_path(username)) + image_path = _add_metadata(image_path, metadata) + + # Copy video to storage + if video is not None: + video_path = _copy_file(video, dst_folder=user_history._user_file_path(username, "videos")) + video_path = _add_metadata(video_path, metadata) + + # Copy audio to storage + if audio is not None: + audio_path = _copy_file(audio, dst_folder=user_history._user_file_path(username, "audios")) + audio_path = _add_metadata(audio_path, metadata) + + # Copy document to storage + if document is not None: + document_path = _copy_file(document, dst_folder=user_history._user_file_path(username, "documents")) + document_path = _add_metadata(document_path, metadata) + + # Save Json file + data = {"image_path": str(image_path), "video_path": str(video_path), "audio_path": str(audio_path), "document_path": str(document_path), "label": label, "metadata": metadata} + with user_history._user_lock(username): + with user_history._user_jsonl_path(username).open("a") as f: + f.write(json.dumps(data) + "\n") + + +############# +# Internals # +############# + + +class _UserHistory(object): + _instance = None + initialized: bool = False + folder_path: Path + + def __new__(cls): + # Using singleton pattern => we don't want to expose an object (more complex to use) but still want to keep + # state between `render` and `save_image` calls. + if cls._instance is None: + cls._instance = super(_UserHistory, cls).__new__(cls) + return cls._instance + + def _user_path(self, username: str) -> Path: + path = self.folder_path / username + path.mkdir(parents=True, exist_ok=True) + return path + + def _user_lock(self, username: str) -> FileLock: + """Ensure history is not corrupted if concurrent calls.""" + return FileLock(self.folder_path / f"{username}.lock") # lock outside of folder => better when exporting ZIP + + def _user_jsonl_path(self, username: str) -> Path: + return self._user_path(username) / "history.jsonl" + + def _user_images_path(self, username: str) -> Path: + path = self._user_path(username) / "images" + path.mkdir(parents=True, exist_ok=True) + return path + + def _user_file_path(self, username: str, filetype: str = "images") -> Path: + path = self._user_path(username) / filetype + path.mkdir(parents=True, exist_ok=True) + return path + + + +def _fetch_user_history(profile: gr.OAuthProfile | None) -> List[Tuple[str, str]]: + """Return saved history for that user, if it exists.""" + # Cannot load history for logged out users + if profile is None: + return [] + username = profile["preferred_username"] + + user_history = _UserHistory() + if not user_history.initialized: + warnings.warn("User history is not set in Gradio demo. 
You must use `user_history.render(...)` first.") + return [] + + with user_history._user_lock(username): + # No file => no history saved yet + jsonl_path = user_history._user_jsonl_path(username) + if not jsonl_path.is_file(): + return [] + + # Read history + images = [] + for line in jsonl_path.read_text().splitlines(): + data = json.loads(line) + images.append((data["path"], data["label"] or "")) + return list(reversed(images)) + + +def _export_user_history(profile: gr.OAuthProfile | None) -> Dict | None: + """Zip all history for that user, if it exists and return it as a downloadable file.""" + # Cannot load history for logged out users + if profile is None: + return None + username = profile["preferred_username"] + + user_history = _UserHistory() + if not user_history.initialized: + warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.") + return None + + # Zip history + with user_history._user_lock(username): + path = shutil.make_archive( + str(_archives_path() / f"history_{username}"), "zip", user_history._user_path(username) + ) + + return gr.update(visible=True, value=path) + + +def _delete_user_history(profile: gr.OAuthProfile | None) -> None: + """Delete all history for that user.""" + # Cannot load history for logged out users + if profile is None: + return + username = profile["preferred_username"] + + user_history = _UserHistory() + if not user_history.initialized: + warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.") + return + + with user_history._user_lock(username): + shutil.rmtree(user_history._user_path(username)) + + +#################### +# Internal helpers # +#################### + + +def _copy_image(image: Image | np.ndarray | str | Path, dst_folder: Path) -> Path: + try: + """Copy image to the images folder.""" + # Already a path => copy it + if isinstance(image, str): + image = Path(image) + if isinstance(image, Path): + dst = dst_folder / f"{uuid4().hex}_{Path(image).name}" # keep file ext + shutil.copyfile(image, dst) + return dst + + # Still a Python object => serialize it + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + if isinstance(image, Image): + dst = dst_folder / f"Path(file).name}_{uuid4().hex}.png" + image.save(dst) + return dst + + raise ValueError(f"Unsupported image type: {type(image)}") + + except Exception as e: + print(f"An error occurred: {e}") + if not isinstance(dst, Path): + dst = Path(image) + return dst # Return the original file_location if an error occurs + +def _copy_file(file: any | np.ndarray | str | Path, dst_folder: Path) -> Path: + try: + """Copy file to the appropriate folder.""" + # Already a path => copy it + if isinstance(file, str): + file = Path(file) + if isinstance(file, Path): + dst = dst_folder / f"{file.stem}_{uuid4().hex}{file.suffix}" # keep file ext + shutil.copyfile(file, dst) + return dst + + # Still a Python object => serialize it + if isinstance(file, np.ndarray): + file = Image.fromarray(file) + dst = dst_folder / f"{file.filename}_{uuid4().hex}{file.suffix}" + file.save(dst) + return dst + + # try other file types + kind = filetype.guess(file) + if kind is not None: + dst = dst_folder / f"{Path(file).stem}_{uuid4().hex}.{kind.extension}" + shutil.copyfile(file, dst) + return dst + raise ValueError(f"Unsupported file type: {type(file)}") + + except Exception as e: + print(f"An error occurred: {e}") + if not isinstance(dst, Path): + dst = Path(file) + return dst # Return the original 
file_location if an error occurs + + +def _add_metadata(file_location: Path, metadata: Dict[str, Any]) -> Path: + try: + file_type = file_location.suffix + valid_file_types = [".wav", ".mp3", ".mp4", ".png"] + if file_type not in valid_file_types: + raise ValueError("Invalid file type. Valid file types are .wav, .mp3, .mp4, .png") + + if file_type == ".wav": + # Open and process .wav file + with wave.open(file_location, 'rb') as wav_file: + # Get the current metadata + current_metadata = {key: value for key, value in wav_file.getparams()._asdict().items() if isinstance(value, (int, float))} + + # Update metadata + current_metadata.update(metadata) + + # Reopen the WAV file in write mode + with wave.open(file_location, 'wb') as wav_output_file: + # Set the new metadata + wav_output_file.setparams(wav_file.getparams()) + + # Save the WAV file (overwriting the previous version) + wav_output_file.close() + elif file_type == ".mp3": + # Open and process .mp3 file + audio = EasyMP3(file_location) + + # Add metadata to the file + for key, value in metadata.items(): + audio[key] = value + + # Save the MP3 file (overwriting the previous version) + audio.save() + elif file_type == ".mp4": + # Open and process .mp4 file + # Add metadata to the file + wav_file_location = file_location.with_suffix(".wav") + wave_exists = wav_file_location.exists() + if not wave_exists: + # Use torchaudio to create the WAV file if it doesn't exist + audio, sample_rate = torchaudio.load(file_location, normalize=True) + torchaudio.save(wav_file_location, audio, sample_rate, format='wav') + + # Use ffmpeg to add metadata to the video file + metadata_args = [f"{key}={value}" for key, value in metadata.items()] + ffmpeg_metadata = ":".join(metadata_args) + ffmpeg_cmd = f'ffmpeg -i "{file_location}" -i "{wav_file_location}" -map 0:v:0 -map 1:a:0 -c:v copy -c:a aac -metadata "{ffmpeg_metadata}" "{file_location}"' + subprocess.run(ffmpeg_cmd, shell=True, check=True) + + # Remove temporary WAV file + if not wave_exists: + wav_file_location.unlink() + elif file_type == ".png": + # Open and process .png file + image = Image.open(file_location) + exif_data = image.info.get("exif", {}) + exif_data.update(metadata) + # Add metadata to the file + image.save(file_location, exif=exif_data) + + return file_location # Return the path to the modified file + + except Exception as e: + print(f"An error occurred: {e}") + return file_location # Return the original file_location if an error occurs + +def _resolve_folder_path(folder_path: str | Path | None) -> Path: + if folder_path is not None: + return Path(folder_path).expanduser().resolve() + + if os.getenv("SYSTEM") == "spaces" and os.path.exists("/data"): # Persistent storage is enabled! 
+ return Path("/data") / "_user_history" + + # Not in a Space or Persistent storage not enabled => local folder + return Path("_user_history").resolve() + + +def _archives_path() -> Path: + # Doesn't have to be on persistent storage as it's only used for download + path = Path(__file__).parent / "_user_history_exports" + path.mkdir(parents=True, exist_ok=True) + return path + + +################# +# Admin section # +################# + + +def _admin_section() -> None: + title = gr.Markdown() + title.attach_load_event(_display_if_admin(), every=None) + + +def _display_if_admin() -> Callable: + def _inner(profile: gr.OAuthProfile | None) -> str: + if profile is None: + return "" + if profile["preferred_username"] in _fetch_admins(): + return _admin_content() + return "" + + return _inner + + +def _admin_content() -> str: + return f""" +## Admin section + +Running on **{os.getenv("SYSTEM", "local")}** (id: {os.getenv("SPACE_ID")}). {_get_msg_is_persistent_storage_enabled()} + +Admins: {', '.join(_fetch_admins())} + +{_get_nb_users()} user(s), {_get_nb_images()} image(s) + +### Configuration + +History folder: *{_UserHistory().folder_path}* + +Exports folder: *{_archives_path()}* + +### Disk usage + +{_disk_space_warning_message()} +""" + + +def _get_nb_users() -> int: + user_history = _UserHistory() + if not user_history.initialized: + return 0 + if user_history.folder_path is not None and user_history.folder_path.exists(): + return len([path for path in user_history.folder_path.iterdir() if path.is_dir()]) + return 0 + + +def _get_nb_images() -> int: + user_history = _UserHistory() + if not user_history.initialized: + return 0 + if user_history.folder_path is not None and user_history.folder_path.exists(): + return len([path for path in user_history.folder_path.glob("*/images/*")]) + return 0 + + +def _get_msg_is_persistent_storage_enabled() -> str: + if os.getenv("SYSTEM") == "spaces": + if os.path.exists("/data"): + return "Persistent storage is enabled." + else: + return ( + "Persistent storage is not enabled. This means that user histories will be deleted when the Space is" + " restarted. Consider adding a Persistent Storage in your Space settings." + ) + return "" + + +def _disk_space_warning_message() -> str: + user_history = _UserHistory() + if not user_history.initialized: + return "" + + message = "" + if user_history.folder_path is not None: + total, used, _ = _get_disk_usage(user_history.folder_path) + message += f"History folder: **{used / 1e9 :.0f}/{total / 1e9 :.0f}GB** used ({100*used/total :.0f}%)." + + total, used, _ = _get_disk_usage(_archives_path()) + message += f"\n\nExports folder: **{used / 1e9 :.0f}/{total / 1e9 :.0f}GB** used ({100*used/total :.0f}%)." + + return f"{message.strip()}" + + +def _get_disk_usage(path: Path) -> Tuple[int, int, int]: + for path in [path] + list(path.parents): # first check target_dir, then each parents one by one + try: + return shutil.disk_usage(path) + except OSError: # if doesn't exist or can't read => fail silently and try parent one + pass + return 0, 0, 0 + + +@cache +def _fetch_admins() -> List[str]: + # Running locally => fake user is admin + if os.getenv("SYSTEM") != "spaces": + return ["FakeGradioUser"] + + # Running in Space but no space_id => ??? 
+ space_id = os.getenv("SPACE_ID") + if space_id is None: + return ["Unknown"] + + # Running in Space => try to fetch organization members + # Otherwise, it's not an organization => namespace is the user + namespace = space_id.split("/")[0] + response = requests.get(f"https://huggingface.co/api/organizations/{namespace}/members") + if response.status_code == 200: + return sorted((member["user"] for member in response.json()), key=lambda x: x.lower()) + return [namespace]
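
A minimal integration sketch for the `user_history` module added above, assuming it is importable as `user_history`; the tab labels and the `/data` folder path are placeholders, not part of the diff.

```python
# Sketch: wiring the history plugin into a Blocks app with setup() and render().
import gradio as gr

import user_history

# Optional: point the history at persistent storage when it is available.
user_history.setup(folder_path="/data/_user_history")

with gr.Blocks() as demo:
    with gr.Tab("Generate"):
        gr.Markdown("...generation UI goes here...")
    with gr.Tab("Past generations"):
        user_history.render()  # login button, gallery, export and delete controls

if __name__ == "__main__":
    demo.launch()
```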
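
A sketch of calling `save_file()` from a generation handler. Gradio's Sign-in-with-HF support injects the profile when a handler declares a `gr.OAuthProfile | None` parameter; `fake_generate` and its output paths are placeholders standing in for the real MusicGen call.

```python
import gradio as gr

import user_history


def fake_generate(prompt: str) -> tuple[str, str]:
    # Placeholder: the real app produces an .mp4 and a .wav per generation.
    return "output.mp4", "output.wav"


def generate_and_log(prompt: str, profile: gr.OAuthProfile | None) -> str:
    video_path, audio_path = fake_generate(prompt)
    user_history.save_file(
        profile=profile,          # None for logged-out users => save is skipped
        video=video_path,
        audio=audio_path,
        label=prompt,
        metadata={"prompt": prompt},
    )
    return video_path
```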
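
In the `save_file()` body above, `video_path`, `audio_path`, and `document_path` are only bound inside their respective `if` blocks but are all referenced when the JSON line is written, so a call that passes only some media types raises `NameError`. A hedged sketch of the record-building step with every path defaulting to `None` (the helper name is mine, not from the diff):

```python
import json
from datetime import datetime


def build_history_record(label, metadata, image_path=None, video_path=None,
                         audio_path=None, document_path=None) -> str:
    """One JSONL line for history.jsonl; missing media simply serialize as null."""
    metadata = dict(metadata or {})
    metadata.setdefault("datetime", str(datetime.now()))
    record = {
        "image_path": str(image_path) if image_path else None,
        "video_path": str(video_path) if video_path else None,
        "audio_path": str(audio_path) if audio_path else None,
        "document_path": str(document_path) if document_path else None,
        "label": label,
        "metadata": metadata,
    }
    return json.dumps(record)
```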
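
The serialization branch of `_copy_image()` (in both copies of the module) builds the destination name from `Path(file).name`, but `file` is not defined in that function; the parameter is `image`, and an in-memory image has no original filename to reuse. A minimal corrected sketch of that branch, assuming a UUID-named PNG is the intent:

```python
from pathlib import Path
from uuid import uuid4

import numpy as np
from PIL import Image


def copy_image_object(image, dst_folder: Path) -> Path:
    """Sketch of the PIL/ndarray branch of _copy_image (illustrative only)."""
    dst_folder.mkdir(parents=True, exist_ok=True)
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if isinstance(image, Image.Image):
        dst = dst_folder / f"{uuid4().hex}.png"  # no source filename exists, so UUID only
        image.save(dst)
        return dst
    raise ValueError(f"Unsupported image type: {type(image)}")
```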
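
The `.mp4` branch of `_add_metadata()` asks ffmpeg to write to the same path it is reading from and folds every tag into a single colon-separated `-metadata` argument. A hedged alternative sketch (the function name and temp-file strategy are mine) that passes one `-metadata key=value` flag per tag and only replaces the original after ffmpeg succeeds:

```python
import subprocess
from pathlib import Path
from typing import Dict


def add_mp4_metadata(video_path: Path, audio_path: Path, metadata: Dict[str, str]) -> Path:
    """Mux the audio track back in and tag the result, then swap it into place."""
    tmp_out = video_path.with_name(video_path.stem + "_tagged.mp4")
    cmd = [
        "ffmpeg", "-y",
        "-i", str(video_path), "-i", str(audio_path),
        "-map", "0:v:0", "-map", "1:a:0",
        "-c:v", "copy", "-c:a", "aac",
    ]
    for key, value in metadata.items():
        cmd += ["-metadata", f"{key}={value}"]  # one -metadata flag per tag
    cmd.append(str(tmp_out))
    subprocess.run(cmd, check=True)  # list form => no shell quoting pitfalls
    tmp_out.replace(video_path)      # overwrite the original only on success
    return video_path
```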
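
The `.png` branch updates `image.info.get("exif", {})`, but Pillow exposes any PNG EXIF payload as raw bytes, not a dict, so the update and the `exif=` save argument will not round-trip string tags. A sketch using Pillow's PNG text chunks instead, which is a simpler fit for key/value metadata:

```python
from pathlib import Path
from typing import Dict

from PIL import Image
from PIL.PngImagePlugin import PngInfo


def add_png_metadata(file_location: Path, metadata: Dict[str, str]) -> Path:
    """Store generation metadata as tEXt chunks in the PNG."""
    image = Image.open(file_location)
    pnginfo = PngInfo()
    for key, value in metadata.items():
        pnginfo.add_text(str(key), str(value))  # one text chunk per key/value pair
    image.save(file_location, pnginfo=pnginfo)
    return file_location
```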
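
`_disk_space_warning_message()` divides by `total`, which is zero whenever `_get_disk_usage()` falls through to `(0, 0, 0)` because neither the folder nor any parent could be read, so the admin panel would raise `ZeroDivisionError` in that case. A guarded sketch of the same parent-walking lookup plus a safe formatter (names are mine):

```python
import shutil
from pathlib import Path
from typing import Tuple


def disk_usage_with_fallback(path: Path) -> Tuple[int, int, int]:
    """Try the folder itself, then each parent, as _get_disk_usage does."""
    for candidate in [path, *path.parents]:
        try:
            return shutil.disk_usage(candidate)
        except OSError:
            continue
    return 0, 0, 0


def format_usage(label: str, path: Path) -> str:
    total, used, _ = disk_usage_with_fallback(path)
    if total == 0:  # unreadable path => avoid dividing by zero
        return f"{label}: usage unavailable."
    return f"{label}: **{used / 1e9:.0f}/{total / 1e9:.0f}GB** used ({100 * used / total:.0f}%)."
```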
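
The per-package helpers in `modules/version_info.py` all repeat the same try/import/return pattern. A possible consolidation, sketched with `importlib.metadata` so the package does not need to be imported just to read its version; this is a suggestion, not a change the diff makes:

```python
from importlib import metadata


def package_version(dist_name: str) -> str:
    """Return the installed version, or "" like the get_*_version helpers above."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return ""


if __name__ == "__main__":
    for name in ("xformers", "transformers", "accelerate", "safetensors", "diffusers"):
        print(f"{name}: {package_version(name) or 'not installed'}")
```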
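
In `versions_html()` the commit link is garbled: the fragment `version: " else commit}" target="_blank">...` has lost the opening `<a href="...{...}"` markup, and the label `autocraft:` presumably should read `audiocraft:`. A hedged reconstruction of just the link fragment; the href target is an assumption (the original URL was stripped), and only the label logic is taken from the source.

```python
def version_link(commit: str) -> str:
    """Rebuild the commit anchor used at the top of the versions table."""
    repo_url = "https://huggingface.co/spaces/Surn/UnlimitedMusicGen"  # assumed target
    label = "huggingface" if commit == "" else commit
    href = repo_url if commit == "" else f"{repo_url}/commit/{commit}"
    return f'version: <a href="{href}" target="_blank">{label}</a>'
```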
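
Given how tightly `requirements.txt` now pins torch 2.6.0 (cu124 wheels), gradio 5.23.3, and av 11.0.0, a small startup check can surface mismatched environments early. The expected versions below mirror the pins in this diff and should be treated as assumptions once the pins move:

```python
from importlib import metadata

import torch

EXPECTED = {"gradio": "5.23.3", "av": "11.0.0"}


def check_environment() -> None:
    print(f"torch {torch.__version__} "
          f"(CUDA build {torch.version.cuda}, available: {torch.cuda.is_available()})")
    for package, wanted in EXPECTED.items():
        try:
            found = metadata.version(package)
        except metadata.PackageNotFoundError:
            found = "missing"
        marker = "OK" if found == wanted else f"expected {wanted}"
        print(f"{package} {found} [{marker}]")


if __name__ == "__main__":
    check_environment()
```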
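
A sketch of how the new `style_20250331.css` can be attached to the Blocks app. Passing the stylesheet text via `css=` and exposing `./assets` through `allowed_paths` (for the SVG background the CSS references) are assumptions about the app wiring, not taken from the diff itself:

```python
from pathlib import Path

import gradio as gr

css = Path("style_20250331.css").read_text()

with gr.Blocks(css=css, title="UnlimitedMusicGen") as demo:
    gr.Markdown("## UnlimitedMusicGen", elem_classes="intro")  # picks up the .intro rule

if __name__ == "__main__":
    demo.launch(allowed_paths=["./assets"])  # background SVG is served from ./assets
```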
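
The organization-members lookup in `_fetch_admins()` above uses an unauthenticated request with no timeout. A hedged variant of the same endpoint adding a request timeout and an optional `HF_TOKEN` bearer header for private organizations; both additions are suggestions, not part of the diff:

```python
import os
from typing import List

import requests


def fetch_org_members(namespace: str) -> List[str]:
    headers = {}
    token = os.getenv("HF_TOKEN")
    if token:
        headers["Authorization"] = f"Bearer {token}"
    response = requests.get(
        f"https://huggingface.co/api/organizations/{namespace}/members",
        headers=headers,
        timeout=10,
    )
    if response.status_code == 200:
        return sorted((member["user"] for member in response.json()), key=str.lower)
    return [namespace]  # personal namespace => the user is the only admin
```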