Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Merge branch 'main' into our_hf2
Browse files- CHANGELOG.md +11 -2
- MODEL_CARD.md +2 -2
- README.md +17 -5
- app.py +151 -85
- app_batched.py +3 -1
- audiocraft/__init__.py +1 -1
- audiocraft/data/audio.py +3 -1
- audiocraft/data/audio_utils.py +9 -4
- audiocraft/models/musicgen.py +2 -0
- audiocraft/modules/conditioners.py +6 -2
- requirements.txt +1 -0
    	
        CHANGELOG.md
    CHANGED
    
    | @@ -4,6 +4,15 @@ All notable changes to this project will be documented in this file. | |
| 4 |  | 
| 5 | 
             
            The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
         | 
| 6 |  | 
| 7 | 
            -
            ## [0.0. | 
| 8 |  | 
| 9 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 4 |  | 
| 5 | 
             
            The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
         | 
| 6 |  | 
| 7 | 
            +
            ## [0.0.2a] - TBD
         | 
| 8 |  | 
| 9 | 
            +
            Improved demo, fixed top p (thanks @jnordberg).
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            Compressor tanh on output to avoid clipping with some style (especially piano).
         | 
| 12 | 
            +
            Now repeating the conditioning periodically if it is too short.
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            More options when launching Gradio app locally (thanks @ashleykleynhans).
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            ## [0.0.1] - 2023-06-09
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            Initial release, with model evaluation only.
         | 
    	
        MODEL_CARD.md
    CHANGED
    
    | @@ -52,7 +52,7 @@ The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/data | |
| 52 |  | 
| 53 | 
             
            ## Training datasets
         | 
| 54 |  | 
| 55 | 
            -
            The model was trained using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound),  [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing.
         | 
| 56 |  | 
| 57 | 
             
            ## Quantitative analysis
         | 
| 58 |  | 
| @@ -62,7 +62,7 @@ More information can be found in the paper [Simple and Controllable Music Genera | |
| 62 |  | 
| 63 | 
             
            **Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model.
         | 
| 64 |  | 
| 65 | 
            -
            **Mitigations:**  | 
| 66 |  | 
| 67 | 
             
            **Limitations:**
         | 
| 68 |  | 
|  | |
| 52 |  | 
| 53 | 
             
            ## Training datasets
         | 
| 54 |  | 
| 55 | 
            +
            The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound),  [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing.
         | 
| 56 |  | 
| 57 | 
             
            ## Quantitative analysis
         | 
| 58 |  | 
|  | |
| 62 |  | 
| 63 | 
             
            **Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model.
         | 
| 64 |  | 
| 65 | 
            +
            **Mitigations:** Vocals have been removed from the data source using corresponding tags, and then using using a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs).
         | 
| 66 |  | 
| 67 | 
             
            **Limitations:**
         | 
| 68 |  | 
    	
        README.md
    CHANGED
    
    | @@ -24,12 +24,12 @@ Audiocraft is a PyTorch library for deep learning research on audio generation. | |
| 24 | 
             
            ## MusicGen
         | 
| 25 |  | 
| 26 | 
             
            Audiocraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive
         | 
| 27 | 
            -
            Transformer model trained over a 32kHz <a href="https://github.com/facebookresearch/encodec">EnCodec tokenizer</a> with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't  | 
| 28 | 
             
            all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict
         | 
| 29 | 
             
            them in parallel, thus having only 50 auto-regressive steps per second of audio.
         | 
| 30 | 
             
            Check out our [sample page][musicgen_samples] or test the available demo!
         | 
| 31 |  | 
| 32 | 
            -
            <a target="_blank" href="https://colab.research.google.com/drive/ | 
| 33 | 
             
              <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
         | 
| 34 | 
             
            </a>
         | 
| 35 | 
             
            <a target="_blank" href="https://huggingface.co/spaces/facebook/MusicGen">
         | 
| @@ -37,6 +37,8 @@ Check out our [sample page][musicgen_samples] or test the available demo! | |
| 37 | 
             
            </a>
         | 
| 38 | 
             
            <br>
         | 
| 39 |  | 
|  | |
|  | |
| 40 | 
             
            ## Installation
         | 
| 41 | 
             
            Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following:
         | 
| 42 |  | 
| @@ -51,7 +53,12 @@ pip install -e .  # or if you cloned the repo locally | |
| 51 | 
             
            ```
         | 
| 52 |  | 
| 53 | 
             
            ## Usage
         | 
| 54 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 55 |  | 
| 56 | 
             
            ## API
         | 
| 57 |  | 
| @@ -68,7 +75,7 @@ GPUs will be able to generate short sequences, or longer sequences with the `sma | |
| 68 | 
             
            **Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using newer version of `torchaudio`.
         | 
| 69 | 
             
            You can install it with:
         | 
| 70 | 
             
            ```
         | 
| 71 | 
            -
            apt | 
| 72 | 
             
            ```
         | 
| 73 |  | 
| 74 | 
             
            See after a quick example for using the API.
         | 
| @@ -90,7 +97,7 @@ wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), s | |
| 90 |  | 
| 91 | 
             
            for idx, one_wav in enumerate(wav):
         | 
| 92 | 
             
                # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
         | 
| 93 | 
            -
                audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness")
         | 
| 94 | 
             
            ```
         | 
| 95 |  | 
| 96 |  | 
| @@ -105,6 +112,11 @@ See [the model card page](./MODEL_CARD.md). | |
| 105 | 
             
            Yes. We will soon release the training code for MusicGen and EnCodec.
         | 
| 106 |  | 
| 107 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 108 | 
             
            ## Citation
         | 
| 109 | 
             
            ```
         | 
| 110 | 
             
            @article{copet2023simple,
         | 
|  | |
| 24 | 
             
            ## MusicGen
         | 
| 25 |  | 
| 26 | 
             
            Audiocraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive
         | 
| 27 | 
            +
            Transformer model trained over a 32kHz <a href="https://github.com/facebookresearch/encodec">EnCodec tokenizer</a> with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require a self-supervised semantic representation, and it generates
         | 
| 28 | 
             
            all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict
         | 
| 29 | 
             
            them in parallel, thus having only 50 auto-regressive steps per second of audio.
         | 
| 30 | 
             
            Check out our [sample page][musicgen_samples] or test the available demo!
         | 
| 31 |  | 
| 32 | 
            +
            <a target="_blank" href="https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing">
         | 
| 33 | 
             
              <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
         | 
| 34 | 
             
            </a>
         | 
| 35 | 
             
            <a target="_blank" href="https://huggingface.co/spaces/facebook/MusicGen">
         | 
|  | |
| 37 | 
             
            </a>
         | 
| 38 | 
             
            <br>
         | 
| 39 |  | 
| 40 | 
            +
            We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data.
         | 
| 41 | 
            +
             | 
| 42 | 
             
            ## Installation
         | 
| 43 | 
             
            Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following:
         | 
| 44 |  | 
|  | |
| 53 | 
             
            ```
         | 
| 54 |  | 
| 55 | 
             
            ## Usage
         | 
| 56 | 
            +
            We offer a number of way to interact with MusicGen:
         | 
| 57 | 
            +
            1. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally, or use the provided [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing).
         | 
| 58 | 
            +
            2. You can use the gradio demo locally by running `python app.py`.
         | 
| 59 | 
            +
            3. A demo is also available on the [`facebook/MusicGen`  HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support).
         | 
| 60 | 
            +
            4. Finally, you can run the [Gradio demo with a Colab GPU](https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing),
         | 
| 61 | 
            +
            as adapted from [@camenduru Colab](https://github.com/camenduru/MusicGen-colab).
         | 
| 62 |  | 
| 63 | 
             
            ## API
         | 
| 64 |  | 
|  | |
| 75 | 
             
            **Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using newer version of `torchaudio`.
         | 
| 76 | 
             
            You can install it with:
         | 
| 77 | 
             
            ```
         | 
| 78 | 
            +
            apt-get install ffmpeg
         | 
| 79 | 
             
            ```
         | 
| 80 |  | 
| 81 | 
             
            See after a quick example for using the API.
         | 
|  | |
| 97 |  | 
| 98 | 
             
            for idx, one_wav in enumerate(wav):
         | 
| 99 | 
             
                # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
         | 
| 100 | 
            +
                audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
         | 
| 101 | 
             
            ```
         | 
| 102 |  | 
| 103 |  | 
|  | |
| 112 | 
             
            Yes. We will soon release the training code for MusicGen and EnCodec.
         | 
| 113 |  | 
| 114 |  | 
| 115 | 
            +
            #### I need help on Windows
         | 
| 116 | 
            +
             | 
| 117 | 
            +
            @FurkanGozukara made a complete tutorial for [Audiocraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
             
            ## Citation
         | 
| 121 | 
             
            ```
         | 
| 122 | 
             
            @article{copet2023simple,
         | 
    	
        app.py
    CHANGED
    
    | @@ -7,14 +7,15 @@ LICENSE file in the root directory of this source tree. | |
| 7 | 
             
            """
         | 
| 8 |  | 
| 9 | 
             
            from tempfile import NamedTemporaryFile
         | 
|  | |
| 10 | 
             
            import torch
         | 
| 11 | 
             
            import gradio as gr
         | 
|  | |
| 12 | 
             
            from audiocraft.models import MusicGen
         | 
| 13 | 
            -
             | 
| 14 | 
             
            from audiocraft.data.audio import audio_write
         | 
| 15 |  | 
| 16 | 
            -
             | 
| 17 | 
             
            MODEL = None
         | 
|  | |
| 18 |  | 
| 19 |  | 
| 20 | 
             
            def load_model(version):
         | 
| @@ -56,95 +57,160 @@ def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef): | |
| 56 |  | 
| 57 | 
             
                output = output.detach().cpu().float()[0]
         | 
| 58 | 
             
                with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
         | 
| 59 | 
            -
                    audio_write( | 
|  | |
|  | |
| 60 | 
             
                    waveform_video = gr.make_waveform(file.name)
         | 
| 61 | 
             
                return waveform_video
         | 
| 62 |  | 
| 63 |  | 
| 64 | 
            -
             | 
| 65 | 
            -
                gr. | 
| 66 | 
            -
                     | 
| 67 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 68 |  | 
| 69 | 
            -
                     | 
| 70 | 
            -
                     | 
| 71 | 
            -
                     | 
| 72 | 
            -
                     | 
| 73 | 
            -
                     | 
| 74 | 
            -
                     | 
| 75 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 76 | 
             
                )
         | 
| 77 | 
            -
                 | 
| 78 | 
            -
                     | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
| 81 | 
            -
             | 
| 82 | 
            -
                        with gr.Row():
         | 
| 83 | 
            -
                            submit = gr.Button("Submit")
         | 
| 84 | 
            -
                        with gr.Row():
         | 
| 85 | 
            -
                            model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
         | 
| 86 | 
            -
                        with gr.Row():
         | 
| 87 | 
            -
                            duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True)
         | 
| 88 | 
            -
                        with gr.Row():
         | 
| 89 | 
            -
                            topk = gr.Number(label="Top-k", value=250, interactive=True)
         | 
| 90 | 
            -
                            topp = gr.Number(label="Top-p", value=0, interactive=True)
         | 
| 91 | 
            -
                            temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
         | 
| 92 | 
            -
                            cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
         | 
| 93 | 
            -
                    with gr.Column():
         | 
| 94 | 
            -
                        output = gr.Video(label="Generated Music")
         | 
| 95 | 
            -
                submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
         | 
| 96 | 
            -
                gr.Examples(
         | 
| 97 | 
            -
                    fn=predict,
         | 
| 98 | 
            -
                    examples=[
         | 
| 99 | 
            -
                        [
         | 
| 100 | 
            -
                            "An 80s driving pop song with heavy drums and synth pads in the background",
         | 
| 101 | 
            -
                            "./assets/bach.mp3",
         | 
| 102 | 
            -
                            "melody"
         | 
| 103 | 
            -
                        ],
         | 
| 104 | 
            -
                        [
         | 
| 105 | 
            -
                            "A cheerful country song with acoustic guitars",
         | 
| 106 | 
            -
                            "./assets/bolero_ravel.mp3",
         | 
| 107 | 
            -
                            "melody"
         | 
| 108 | 
            -
                        ],
         | 
| 109 | 
            -
                        [
         | 
| 110 | 
            -
                            "90s rock song with electric guitar and heavy drums",
         | 
| 111 | 
            -
                            None,
         | 
| 112 | 
            -
                            "medium"
         | 
| 113 | 
            -
                        ],
         | 
| 114 | 
            -
                        [
         | 
| 115 | 
            -
                            "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
         | 
| 116 | 
            -
                            "./assets/bach.mp3",
         | 
| 117 | 
            -
                            "melody"
         | 
| 118 | 
            -
                        ],
         | 
| 119 | 
            -
                        [
         | 
| 120 | 
            -
                            "lofi slow bpm electro chill with organic samples",
         | 
| 121 | 
            -
                            None,
         | 
| 122 | 
            -
                            "medium",
         | 
| 123 | 
            -
                        ],
         | 
| 124 | 
            -
                    ],
         | 
| 125 | 
            -
                    inputs=[text, melody, model],
         | 
| 126 | 
            -
                    outputs=[output]
         | 
| 127 | 
             
                )
         | 
| 128 | 
            -
                 | 
| 129 | 
            -
                     | 
| 130 | 
            -
                     | 
| 131 | 
            -
             | 
| 132 | 
            -
                     | 
| 133 | 
            -
             | 
| 134 | 
            -
             | 
| 135 | 
            -
                     | 
| 136 | 
            -
             | 
| 137 | 
            -
             | 
| 138 | 
            -
                     | 
| 139 | 
            -
                    4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
         | 
| 140 | 
            -
             | 
| 141 | 
            -
                    When using `melody`, ou can optionaly provide a reference audio from
         | 
| 142 | 
            -
                    which a broad melody will be extracted. The model will then try to follow both the description and melody provided.
         | 
| 143 | 
            -
             | 
| 144 | 
            -
                    You can also use your own GPU or a Google Colab by following the instructions on our repo.
         | 
| 145 | 
            -
                    See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
         | 
| 146 | 
            -
                    for more details.
         | 
| 147 | 
            -
                    """
         | 
| 148 | 
             
                )
         | 
| 149 |  | 
| 150 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 7 | 
             
            """
         | 
| 8 |  | 
| 9 | 
             
            from tempfile import NamedTemporaryFile
         | 
| 10 | 
            +
            import argparse
         | 
| 11 | 
             
            import torch
         | 
| 12 | 
             
            import gradio as gr
         | 
| 13 | 
            +
            import os
         | 
| 14 | 
             
            from audiocraft.models import MusicGen
         | 
|  | |
| 15 | 
             
            from audiocraft.data.audio import audio_write
         | 
| 16 |  | 
|  | |
| 17 | 
             
            MODEL = None
         | 
| 18 | 
            +
            IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ['SPACE_ID']
         | 
| 19 |  | 
| 20 |  | 
| 21 | 
             
            def load_model(version):
         | 
|  | |
| 57 |  | 
| 58 | 
             
                output = output.detach().cpu().float()[0]
         | 
| 59 | 
             
                with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
         | 
| 60 | 
            +
                    audio_write(
         | 
| 61 | 
            +
                        file.name, output, MODEL.sample_rate, strategy="loudness",
         | 
| 62 | 
            +
                        loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
         | 
| 63 | 
             
                    waveform_video = gr.make_waveform(file.name)
         | 
| 64 | 
             
                return waveform_video
         | 
| 65 |  | 
| 66 |  | 
| 67 | 
            +
            def ui(**kwargs):
         | 
| 68 | 
            +
                with gr.Blocks() as interface:
         | 
| 69 | 
            +
                    gr.Markdown(
         | 
| 70 | 
            +
                        """
         | 
| 71 | 
            +
                        # MusicGen
         | 
| 72 | 
            +
                        This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
         | 
| 73 | 
            +
                        presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
         | 
| 74 | 
            +
                        """
         | 
| 75 | 
            +
                    )
         | 
| 76 | 
            +
                    if IS_SHARED_SPACE:
         | 
| 77 | 
            +
                        gr.Markdown("""
         | 
| 78 | 
            +
                            ⚠ This Space doesn't work in this shared UI ⚠
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                            <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
         | 
| 81 | 
            +
                            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
         | 
| 82 | 
            +
                            to use it privately, or use the <a href="https://huggingface.co/spaces/facebook/MusicGen">public demo</a>
         | 
| 83 | 
            +
                            """)
         | 
| 84 | 
            +
                    with gr.Row():
         | 
| 85 | 
            +
                        with gr.Column():
         | 
| 86 | 
            +
                            with gr.Row():
         | 
| 87 | 
            +
                                text = gr.Text(label="Input Text", interactive=True)
         | 
| 88 | 
            +
                                melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
         | 
| 89 | 
            +
                            with gr.Row():
         | 
| 90 | 
            +
                                submit = gr.Button("Submit")
         | 
| 91 | 
            +
                            with gr.Row():
         | 
| 92 | 
            +
                                model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
         | 
| 93 | 
            +
                            with gr.Row():
         | 
| 94 | 
            +
                                duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True)
         | 
| 95 | 
            +
                            with gr.Row():
         | 
| 96 | 
            +
                                topk = gr.Number(label="Top-k", value=250, interactive=True)
         | 
| 97 | 
            +
                                topp = gr.Number(label="Top-p", value=0, interactive=True)
         | 
| 98 | 
            +
                                temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
         | 
| 99 | 
            +
                                cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
         | 
| 100 | 
            +
                        with gr.Column():
         | 
| 101 | 
            +
                            output = gr.Video(label="Generated Music")
         | 
| 102 | 
            +
                    submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
         | 
| 103 | 
            +
                    gr.Examples(
         | 
| 104 | 
            +
                        fn=predict,
         | 
| 105 | 
            +
                        examples=[
         | 
| 106 | 
            +
                            [
         | 
| 107 | 
            +
                                "An 80s driving pop song with heavy drums and synth pads in the background",
         | 
| 108 | 
            +
                                "./assets/bach.mp3",
         | 
| 109 | 
            +
                                "melody"
         | 
| 110 | 
            +
                            ],
         | 
| 111 | 
            +
                            [
         | 
| 112 | 
            +
                                "A cheerful country song with acoustic guitars",
         | 
| 113 | 
            +
                                "./assets/bolero_ravel.mp3",
         | 
| 114 | 
            +
                                "melody"
         | 
| 115 | 
            +
                            ],
         | 
| 116 | 
            +
                            [
         | 
| 117 | 
            +
                                "90s rock song with electric guitar and heavy drums",
         | 
| 118 | 
            +
                                None,
         | 
| 119 | 
            +
                                "medium"
         | 
| 120 | 
            +
                            ],
         | 
| 121 | 
            +
                            [
         | 
| 122 | 
            +
                                "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
         | 
| 123 | 
            +
                                "./assets/bach.mp3",
         | 
| 124 | 
            +
                                "melody"
         | 
| 125 | 
            +
                            ],
         | 
| 126 | 
            +
                            [
         | 
| 127 | 
            +
                                "lofi slow bpm electro chill with organic samples",
         | 
| 128 | 
            +
                                None,
         | 
| 129 | 
            +
                                "medium",
         | 
| 130 | 
            +
                            ],
         | 
| 131 | 
            +
                        ],
         | 
| 132 | 
            +
                        inputs=[text, melody, model],
         | 
| 133 | 
            +
                        outputs=[output]
         | 
| 134 | 
            +
                    )
         | 
| 135 | 
            +
                    gr.Markdown(
         | 
| 136 | 
            +
                        """
         | 
| 137 | 
            +
                        ### More details
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                        The model will generate a short music extract based on the description you provided.
         | 
| 140 | 
            +
                        You can generate up to 30 seconds of audio.
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                        We present 4 model variations:
         | 
| 143 | 
            +
                        1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
         | 
| 144 | 
            +
                        2. Small -- a 300M transformer decoder conditioned on text only.
         | 
| 145 | 
            +
                        3. Medium -- a 1.5B transformer decoder conditioned on text only.
         | 
| 146 | 
            +
                        4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                        When using `melody`, ou can optionaly provide a reference audio from
         | 
| 149 | 
            +
                        which a broad melody will be extracted. The model will then try to follow both the description and melody provided.
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                        You can also use your own GPU or a Google Colab by following the instructions on our repo.
         | 
| 152 | 
            +
                        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
         | 
| 153 | 
            +
                        for more details.
         | 
| 154 | 
            +
                        """
         | 
| 155 | 
            +
                    )
         | 
| 156 |  | 
| 157 | 
            +
                    # Show the interface
         | 
| 158 | 
            +
                    launch_kwargs = {}
         | 
| 159 | 
            +
                    username = kwargs.get('username')
         | 
| 160 | 
            +
                    password = kwargs.get('password')
         | 
| 161 | 
            +
                    server_port = kwargs.get('server_port', 0)
         | 
| 162 | 
            +
                    inbrowser = kwargs.get('inbrowser', False)
         | 
| 163 | 
            +
                    share = kwargs.get('share', False)
         | 
| 164 | 
            +
                    server_name = kwargs.get('listen')
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    launch_kwargs['server_name'] = server_name
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    if username and password:
         | 
| 169 | 
            +
                        launch_kwargs['auth'] = (username, password)
         | 
| 170 | 
            +
                    if server_port > 0:
         | 
| 171 | 
            +
                        launch_kwargs['server_port'] = server_port
         | 
| 172 | 
            +
                    if inbrowser:
         | 
| 173 | 
            +
                        launch_kwargs['inbrowser'] = inbrowser
         | 
| 174 | 
            +
                    if share:
         | 
| 175 | 
            +
                        launch_kwargs['share'] = share
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    interface.queue().launch(**launch_kwargs, max_threads=1)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
             | 
| 180 | 
            +
            if __name__ == "__main__":
         | 
| 181 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 182 | 
            +
                parser.add_argument(
         | 
| 183 | 
            +
                    '--listen',
         | 
| 184 | 
            +
                    type=str,
         | 
| 185 | 
            +
                    default='127.0.0.1',
         | 
| 186 | 
            +
                    help='IP to listen on for connections to Gradio',
         | 
| 187 | 
             
                )
         | 
| 188 | 
            +
                parser.add_argument(
         | 
| 189 | 
            +
                    '--username', type=str, default='', help='Username for authentication'
         | 
| 190 | 
            +
                )
         | 
| 191 | 
            +
                parser.add_argument(
         | 
| 192 | 
            +
                    '--password', type=str, default='', help='Password for authentication'
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 193 | 
             
                )
         | 
| 194 | 
            +
                parser.add_argument(
         | 
| 195 | 
            +
                    '--server_port',
         | 
| 196 | 
            +
                    type=int,
         | 
| 197 | 
            +
                    default=0,
         | 
| 198 | 
            +
                    help='Port to run the server listener on',
         | 
| 199 | 
            +
                )
         | 
| 200 | 
            +
                parser.add_argument(
         | 
| 201 | 
            +
                    '--inbrowser', action='store_true', help='Open in browser'
         | 
| 202 | 
            +
                )
         | 
| 203 | 
            +
                parser.add_argument(
         | 
| 204 | 
            +
                    '--share', action='store_true', help='Share the gradio UI'
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 205 | 
             
                )
         | 
| 206 |  | 
| 207 | 
            +
                args = parser.parse_args()
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                ui(
         | 
| 210 | 
            +
                    username=args.username,
         | 
| 211 | 
            +
                    password=args.password,
         | 
| 212 | 
            +
                    inbrowser=args.inbrowser,
         | 
| 213 | 
            +
                    server_port=args.server_port,
         | 
| 214 | 
            +
                    share=args.share,
         | 
| 215 | 
            +
                    listen=args.listen
         | 
| 216 | 
            +
                )
         | 
    	
        app_batched.py
    CHANGED
    
    | @@ -57,7 +57,9 @@ def predict(texts, melodies): | |
| 57 | 
             
                out_files = []
         | 
| 58 | 
             
                for output in outputs:
         | 
| 59 | 
             
                    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
         | 
| 60 | 
            -
                        audio_write( | 
|  | |
|  | |
| 61 | 
             
                        waveform_video = gr.make_waveform(file.name)
         | 
| 62 | 
             
                        out_files.append(waveform_video)
         | 
| 63 | 
             
                return [out_files]
         | 
|  | |
| 57 | 
             
                out_files = []
         | 
| 58 | 
             
                for output in outputs:
         | 
| 59 | 
             
                    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
         | 
| 60 | 
            +
                        audio_write(
         | 
| 61 | 
            +
                            file.name, output, MODEL.sample_rate, strategy="loudness",
         | 
| 62 | 
            +
                            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
         | 
| 63 | 
             
                        waveform_video = gr.make_waveform(file.name)
         | 
| 64 | 
             
                        out_files.append(waveform_video)
         | 
| 65 | 
             
                return [out_files]
         | 
    	
        audiocraft/__init__.py
    CHANGED
    
    | @@ -7,4 +7,4 @@ | |
| 7 | 
             
            # flake8: noqa
         | 
| 8 | 
             
            from . import data, modules, models
         | 
| 9 |  | 
| 10 | 
            -
            __version__ = '0.0. | 
|  | |
| 7 | 
             
            # flake8: noqa
         | 
| 8 | 
             
            from . import data, modules, models
         | 
| 9 |  | 
| 10 | 
            +
            __version__ = '0.0.2a1'
         | 
    	
        audiocraft/data/audio.py
    CHANGED
    
    | @@ -155,6 +155,7 @@ def audio_write(stem_name: tp.Union[str, Path], | |
| 155 | 
             
                            format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
         | 
| 156 | 
             
                            strategy: str = 'peak', peak_clip_headroom_db: float = 1,
         | 
| 157 | 
             
                            rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
         | 
|  | |
| 158 | 
             
                            log_clipping: bool = True, make_parent_dir: bool = True,
         | 
| 159 | 
             
                            add_suffix: bool = True) -> Path:
         | 
| 160 | 
             
                """Convenience function for saving audio to disk. Returns the filename the audio was written to.
         | 
| @@ -173,7 +174,8 @@ def audio_write(stem_name: tp.Union[str, Path], | |
| 173 | 
             
                    rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
         | 
| 174 | 
             
                        than the `peak_clip` one to avoid further clipping.
         | 
| 175 | 
             
                    loudness_headroom_db (float): Target loudness for loudness normalization.
         | 
| 176 | 
            -
                     | 
|  | |
| 177 | 
             
                        occurs despite strategy (only for 'rms').
         | 
| 178 | 
             
                    make_parent_dir (bool): Make parent directory if it doesn't exist.
         | 
| 179 | 
             
                Returns:
         | 
|  | |
| 155 | 
             
                            format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
         | 
| 156 | 
             
                            strategy: str = 'peak', peak_clip_headroom_db: float = 1,
         | 
| 157 | 
             
                            rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
         | 
| 158 | 
            +
                            loudness_compressor: bool = False,
         | 
| 159 | 
             
                            log_clipping: bool = True, make_parent_dir: bool = True,
         | 
| 160 | 
             
                            add_suffix: bool = True) -> Path:
         | 
| 161 | 
             
                """Convenience function for saving audio to disk. Returns the filename the audio was written to.
         | 
|  | |
| 174 | 
             
                    rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
         | 
| 175 | 
             
                        than the `peak_clip` one to avoid further clipping.
         | 
| 176 | 
             
                    loudness_headroom_db (float): Target loudness for loudness normalization.
         | 
| 177 | 
            +
                    loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
         | 
| 178 | 
            +
                     when strategy is 'loudness'log_clipping (bool): If True, basic logging on stderr when clipping still
         | 
| 179 | 
             
                        occurs despite strategy (only for 'rms').
         | 
| 180 | 
             
                    make_parent_dir (bool): Make parent directory if it doesn't exist.
         | 
| 181 | 
             
                Returns:
         | 
    	
        audiocraft/data/audio_utils.py
    CHANGED
    
    | @@ -54,8 +54,8 @@ def convert_audio(wav: torch.Tensor, from_rate: float, | |
| 54 | 
             
                return wav
         | 
| 55 |  | 
| 56 |  | 
| 57 | 
            -
            def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float =  | 
| 58 | 
            -
                                   energy_floor: float = 2e-3):
         | 
| 59 | 
             
                """Normalize an input signal to a user loudness in dB LKFS.
         | 
| 60 | 
             
                Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
         | 
| 61 |  | 
| @@ -63,6 +63,7 @@ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db | |
| 63 | 
             
                    wav (torch.Tensor): Input multichannel audio data.
         | 
| 64 | 
             
                    sample_rate (int): Sample rate.
         | 
| 65 | 
             
                    loudness_headroom_db (float): Target loudness of the output in dB LUFS.
         | 
|  | |
| 66 | 
             
                    energy_floor (float): anything below that RMS level will not be rescaled.
         | 
| 67 | 
             
                Returns:
         | 
| 68 | 
             
                    output (torch.Tensor): Loudness normalized output data.
         | 
| @@ -76,6 +77,8 @@ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db | |
| 76 | 
             
                delta_loudness = -loudness_headroom_db - input_loudness_db
         | 
| 77 | 
             
                gain = 10.0 ** (delta_loudness / 20.0)
         | 
| 78 | 
             
                output = gain * wav
         | 
|  | |
|  | |
| 79 | 
             
                assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
         | 
| 80 | 
             
                return output
         | 
| 81 |  | 
| @@ -93,7 +96,8 @@ def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optio | |
| 93 | 
             
            def normalize_audio(wav: torch.Tensor, normalize: bool = True,
         | 
| 94 | 
             
                                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
         | 
| 95 | 
             
                                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
         | 
| 96 | 
            -
                                 | 
|  | |
| 97 | 
             
                                stem_name: tp.Optional[str] = None) -> torch.Tensor:
         | 
| 98 | 
             
                """Normalize the audio according to the prescribed strategy (see after).
         | 
| 99 |  | 
| @@ -109,6 +113,7 @@ def normalize_audio(wav: torch.Tensor, normalize: bool = True, | |
| 109 | 
             
                    rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
         | 
| 110 | 
             
                        than the `peak_clip` one to avoid further clipping.
         | 
| 111 | 
             
                    loudness_headroom_db (float): Target loudness for loudness normalization.
         | 
|  | |
| 112 | 
             
                    log_clipping (bool): If True, basic logging on stderr when clipping still
         | 
| 113 | 
             
                        occurs despite strategy (only for 'rms').
         | 
| 114 | 
             
                    sample_rate (int): Sample rate for the audio data (required for loudness).
         | 
| @@ -132,7 +137,7 @@ def normalize_audio(wav: torch.Tensor, normalize: bool = True, | |
| 132 | 
             
                    _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
         | 
| 133 | 
             
                elif strategy == 'loudness':
         | 
| 134 | 
             
                    assert sample_rate is not None, "Loudness normalization requires sample rate."
         | 
| 135 | 
            -
                    wav = normalize_loudness(wav, sample_rate, loudness_headroom_db)
         | 
| 136 | 
             
                    _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
         | 
| 137 | 
             
                else:
         | 
| 138 | 
             
                    assert wav.abs().max() < 1
         | 
|  | |
| 54 | 
             
                return wav
         | 
| 55 |  | 
| 56 |  | 
| 57 | 
            +
            def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
         | 
| 58 | 
            +
                                   loudness_compressor: bool = False, energy_floor: float = 2e-3):
         | 
| 59 | 
             
                """Normalize an input signal to a user loudness in dB LKFS.
         | 
| 60 | 
             
                Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
         | 
| 61 |  | 
|  | |
| 63 | 
             
                    wav (torch.Tensor): Input multichannel audio data.
         | 
| 64 | 
             
                    sample_rate (int): Sample rate.
         | 
| 65 | 
             
                    loudness_headroom_db (float): Target loudness of the output in dB LUFS.
         | 
| 66 | 
            +
                    loudness_compressor (bool): Uses tanh for soft clipping.
         | 
| 67 | 
             
                    energy_floor (float): anything below that RMS level will not be rescaled.
         | 
| 68 | 
             
                Returns:
         | 
| 69 | 
             
                    output (torch.Tensor): Loudness normalized output data.
         | 
|  | |
| 77 | 
             
                delta_loudness = -loudness_headroom_db - input_loudness_db
         | 
| 78 | 
             
                gain = 10.0 ** (delta_loudness / 20.0)
         | 
| 79 | 
             
                output = gain * wav
         | 
| 80 | 
            +
                if loudness_compressor:
         | 
| 81 | 
            +
                    output = torch.tanh(output)
         | 
| 82 | 
             
                assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
         | 
| 83 | 
             
                return output
         | 
| 84 |  | 
|  | |
| 96 | 
             
            def normalize_audio(wav: torch.Tensor, normalize: bool = True,
         | 
| 97 | 
             
                                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
         | 
| 98 | 
             
                                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
         | 
| 99 | 
            +
                                loudness_compressor: bool = False, log_clipping: bool = False,
         | 
| 100 | 
            +
                                sample_rate: tp.Optional[int] = None,
         | 
| 101 | 
             
                                stem_name: tp.Optional[str] = None) -> torch.Tensor:
         | 
| 102 | 
             
                """Normalize the audio according to the prescribed strategy (see after).
         | 
| 103 |  | 
|  | |
| 113 | 
             
                    rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
         | 
| 114 | 
             
                        than the `peak_clip` one to avoid further clipping.
         | 
| 115 | 
             
                    loudness_headroom_db (float): Target loudness for loudness normalization.
         | 
| 116 | 
            +
                    loudness_compressor (bool): If True, uses tanh based soft clipping.
         | 
| 117 | 
             
                    log_clipping (bool): If True, basic logging on stderr when clipping still
         | 
| 118 | 
             
                        occurs despite strategy (only for 'rms').
         | 
| 119 | 
             
                    sample_rate (int): Sample rate for the audio data (required for loudness).
         | 
|  | |
| 137 | 
             
                    _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
         | 
| 138 | 
             
                elif strategy == 'loudness':
         | 
| 139 | 
             
                    assert sample_rate is not None, "Loudness normalization requires sample rate."
         | 
| 140 | 
            +
                    wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
         | 
| 141 | 
             
                    _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
         | 
| 142 | 
             
                else:
         | 
| 143 | 
             
                    assert wav.abs().max() < 1
         | 
    	
        audiocraft/models/musicgen.py
    CHANGED
    
    | @@ -88,6 +88,8 @@ class MusicGen: | |
| 88 | 
             
                    cache_dir = os.environ.get('MUSICGEN_ROOT', None)
         | 
| 89 | 
             
                    compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
         | 
| 90 | 
             
                    lm = load_lm_model(name, device=device, cache_dir=cache_dir)
         | 
|  | |
|  | |
| 91 |  | 
| 92 | 
             
                    return MusicGen(name, compression_model, lm)
         | 
| 93 |  | 
|  | |
| 88 | 
             
                    cache_dir = os.environ.get('MUSICGEN_ROOT', None)
         | 
| 89 | 
             
                    compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
         | 
| 90 | 
             
                    lm = load_lm_model(name, device=device, cache_dir=cache_dir)
         | 
| 91 | 
            +
                    if name == 'melody':
         | 
| 92 | 
            +
                        lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
         | 
| 93 |  | 
| 94 | 
             
                    return MusicGen(name, compression_model, lm)
         | 
| 95 |  | 
    	
        audiocraft/modules/conditioners.py
    CHANGED
    
    | @@ -9,6 +9,7 @@ from copy import deepcopy | |
| 9 | 
             
            from dataclasses import dataclass, field
         | 
| 10 | 
             
            from itertools import chain
         | 
| 11 | 
             
            import logging
         | 
|  | |
| 12 | 
             
            import random
         | 
| 13 | 
             
            import re
         | 
| 14 | 
             
            import typing as tp
         | 
| @@ -484,7 +485,7 @@ class ChromaStemConditioner(WaveformConditioner): | |
| 484 | 
             
                    **kwargs: Additional parameters for the chroma extractor.
         | 
| 485 | 
             
                """
         | 
| 486 | 
             
                def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
         | 
| 487 | 
            -
                             duration: float, match_len_on_eval: bool =  | 
| 488 | 
             
                             n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
         | 
| 489 | 
             
                    from demucs import pretrained
         | 
| 490 | 
             
                    super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
         | 
| @@ -535,7 +536,10 @@ class ChromaStemConditioner(WaveformConditioner): | |
| 535 | 
             
                            chroma = chroma[:, :self.chroma_len]
         | 
| 536 | 
             
                            logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
         | 
| 537 | 
             
                        elif t < self.chroma_len:
         | 
| 538 | 
            -
                            chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
         | 
|  | |
|  | |
|  | |
| 539 | 
             
                            logger.debug(f'chroma was zero-padded! ({t} -> {chroma.shape[1]})')
         | 
| 540 | 
             
                    return chroma
         | 
| 541 |  | 
|  | |
| 9 | 
             
            from dataclasses import dataclass, field
         | 
| 10 | 
             
            from itertools import chain
         | 
| 11 | 
             
            import logging
         | 
| 12 | 
            +
            import math
         | 
| 13 | 
             
            import random
         | 
| 14 | 
             
            import re
         | 
| 15 | 
             
            import typing as tp
         | 
|  | |
| 485 | 
             
                    **kwargs: Additional parameters for the chroma extractor.
         | 
| 486 | 
             
                """
         | 
| 487 | 
             
                def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
         | 
| 488 | 
            +
                             duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
         | 
| 489 | 
             
                             n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
         | 
| 490 | 
             
                    from demucs import pretrained
         | 
| 491 | 
             
                    super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
         | 
|  | |
| 536 | 
             
                            chroma = chroma[:, :self.chroma_len]
         | 
| 537 | 
             
                            logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
         | 
| 538 | 
             
                        elif t < self.chroma_len:
         | 
| 539 | 
            +
                            # chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
         | 
| 540 | 
            +
                            n_repeat = int(math.ceil(self.chroma_len / t))
         | 
| 541 | 
            +
                            chroma = chroma.repeat(1, n_repeat, 1)
         | 
| 542 | 
            +
                            chroma = chroma[:, :self.chroma_len]
         | 
| 543 | 
             
                            logger.debug(f'chroma was zero-padded! ({t} -> {chroma.shape[1]})')
         | 
| 544 | 
             
                    return chroma
         | 
| 545 |  | 
    	
        requirements.txt
    CHANGED
    
    | @@ -17,3 +17,4 @@ transformers | |
| 17 | 
             
            xformers
         | 
| 18 | 
             
            demucs
         | 
| 19 | 
             
            librosa
         | 
|  | 
|  | |
| 17 | 
             
            xformers
         | 
| 18 | 
             
            demucs
         | 
| 19 | 
             
            librosa
         | 
| 20 | 
            +
            gradio
         | 
 
			
